Compare commits

..

37 Commits

Author SHA1 Message Date
Nicholas Tindle
3610be3e83 Merge branch 'dev' into ntindle/waitlist 2026-01-20 17:47:02 -06:00
Nicholas Tindle
9e1f7c9415 Merge branch 'dev' into ntindle/waitlist 2026-01-19 01:12:14 -06:00
Nicholas Tindle
0d03ebb43c fix: lint 2026-01-16 11:34:00 -06:00
Nicholas Tindle
1b37bd6da9 Merge branch 'dev' into ntindle/waitlist 2026-01-16 11:32:05 -06:00
Nicholas Tindle
db989a5eed fix: lint 2026-01-15 15:58:33 -06:00
Nicholas Tindle
e3a8c57a35 Merge branch 'dev' into ntindle/waitlist
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:54:38 -06:00
Nicholas Tindle
dfc8e53386 fix(backend): add assertions to fix type errors in waitlist admin functions
Prisma's update() returns T | None but we verify existence before updating,
so assert the result is not None to satisfy the type checker.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:48:30 -06:00
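A minimal sketch of the assertion pattern described in the commit above, assuming the Prisma Python client and a hypothetical `waitlist` model; names and fields are illustrative, not the actual code:

from prisma import Prisma

async def update_waitlist_status(db: Prisma, waitlist_id: str, new_status: str):
    # Verify the record exists first, mirroring the pattern the commit describes.
    existing = await db.waitlist.find_unique(where={"id": waitlist_id})
    if existing is None:
        raise ValueError("Waitlist not found")
    updated = await db.waitlist.update(
        where={"id": waitlist_id},
        data={"status": new_status},
    )
    # update() is typed as returning the record or None; existence was checked
    # above, so the assertion narrows the type for the checker.
    assert updated is not None
    return updated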
Nicholas Tindle
b5b7e5da92 fix(backend): don't mark waitlist DONE if email-only users pending
The notify_waitlist_users_on_launch function was marking waitlists as
DONE after notifying registered users, but ignoring unaffiliatedEmailUsers
who haven't been notified yet. Since DONE waitlists are excluded from
future notification queries, those email users would never receive
notifications when that functionality is implemented.

Now the waitlist remains in an active state if there are pending
email-only signups that still need notifications.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:41:31 -06:00
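A sketch of the guard described above; the field names (status, unaffiliatedEmailUsers) follow the commit messages in this branch, and the helper itself is hypothetical:

async def finish_waitlist_if_fully_notified(db, waitlist_id: str) -> None:
    waitlist = await db.waitlist.find_unique(where={"id": waitlist_id})
    if waitlist is None:
        return
    # Only mark the waitlist DONE when no email-only signups remain to be
    # notified; otherwise keep it active so a later notification pass can
    # still pick those addresses up.
    if not waitlist.unaffiliatedEmailUsers:
        await db.waitlist.update(
            where={"id": waitlist_id},
            data={"status": "DONE"},
        )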
Nicholas Tindle
07ea2c2ab7 fix(backend): check waitlist existence before update in update_waitlist_admin
Added find_unique check before update() call to properly return 404 when
waitlist doesn't exist, following the established pattern used in other
waitlist admin functions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:37:19 -06:00
Nicholas Tindle
9c873a0158 fix(backend): add exception handling to add_self_to_waitlist route
The public waitlist join route was missing exception handling, causing
500 errors for all failures. Now properly returns:
- 404 for waitlist not found
- 400 for closed/unavailable waitlists
- 500 for unexpected errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:30:54 -06:00
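A sketch of the error mapping described above, following the JSONResponse style used by the admin router later in this diff; `WaitlistClosedError` and `join_waitlist` are assumed names, not the actual implementation:

import logging

import fastapi
import fastapi.responses

logger = logging.getLogger(__name__)
router = fastapi.APIRouter(prefix="/waitlist")

class WaitlistClosedError(Exception):
    """Raised when a waitlist is closed or unavailable (assumed exception)."""

async def join_waitlist(waitlist_id: str) -> dict:
    ...  # placeholder for the real db helper

@router.post("/{waitlist_id}/join")
async def add_self_to_waitlist(waitlist_id: str):
    try:
        return await join_waitlist(waitlist_id)
    except ValueError:
        return fastapi.responses.JSONResponse(
            status_code=404, content={"detail": "Waitlist not found"}
        )
    except WaitlistClosedError:
        return fastapi.responses.JSONResponse(
            status_code=400, content={"detail": "Waitlist is closed or unavailable"}
        )
    except Exception as e:
        logger.exception("Error joining waitlist: %s", e)
        return fastapi.responses.JSONResponse(
            status_code=500, content={"detail": "An unexpected error occurred"}
        )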
Nicholas Tindle
ed634db8f7 fix(backend): validate waitlist status enum at API boundary
Changed WaitlistUpdateRequest.status from str to the actual enum type.
Pydantic now validates the status value, returning 422 for invalid
values instead of a misleading 404 "Waitlist not found" error.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:26:26 -06:00
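A sketch of enum validation at the API boundary; the exact set of status values is an assumption beyond DONE and CANCELED, which appear elsewhere in this PR:

import enum

import pydantic

class WaitlistStatus(str, enum.Enum):
    OPEN = "OPEN"  # assumed value
    DONE = "DONE"
    CANCELED = "CANCELED"

class WaitlistUpdateRequest(pydantic.BaseModel):
    # Typing the field as the enum rather than str lets Pydantic reject bad
    # values, which FastAPI surfaces as a 422 instead of a misleading 404.
    status: WaitlistStatus | None = None

try:
    WaitlistUpdateRequest(status="LAUNCHED")
except pydantic.ValidationError as err:
    print(err)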
Nicholas Tindle
398197f3ea fix(frontend): add title attribute to YouTube iframe for accessibility
Screen readers need a title attribute on iframes to describe their
content. Added "YouTube video player" title to the embedded video.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:23:52 -06:00
Nicholas Tindle
b7df4cfdbf fix(backend): align migration FK with schema (SET NULL not CASCADE)
The migration had ON DELETE CASCADE for WaitlistEntry.storeListingId,
but the Prisma schema specifies onDelete: SetNull. This mismatch would
cause waitlist entries and all signup data to be deleted when a store
listing is removed, instead of just unlinking them.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:18:03 -06:00
Nicholas Tindle
5d8dd46759 fix(backend): align waitlist admin functions with established patterns
- delete_waitlist_admin: add find_unique check before update, raise
  ValueError if not found, add except ValueError: raise
- link_waitlist_to_listing_admin: add find_unique check for waitlist
  before update, remove dead code
- delete_waitlist route: add except ValueError: → 404, remove dead
  code bool check pattern

All waitlist admin functions now follow the consistent pattern:
1. find_unique to check existence
2. raise ValueError if not found
3. except ValueError: raise to bubble up
4. except Exception: raise DatabaseError

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:54:53 -06:00
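A compact sketch of the four-step pattern listed above; `DatabaseError` and the model/field names are assumptions:

class DatabaseError(Exception):
    """Assumed wrapper for unexpected persistence failures."""

async def delete_waitlist_admin(db, waitlist_id: str):
    try:
        # 1. find_unique to check existence
        existing = await db.waitlist.find_unique(where={"id": waitlist_id})
        if existing is None:
            # 2. raise ValueError if not found
            raise ValueError("Waitlist not found")
        return await db.waitlist.update(
            where={"id": waitlist_id}, data={"isDeleted": True}
        )
    except ValueError:
        # 3. bubble up so the route can translate it into a 404
        raise
    except Exception as e:
        # 4. wrap everything else
        raise DatabaseError("Failed to delete waitlist") from e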
Nicholas Tindle
f9518b6f8b fix(frontend): use generated query key for waitlist cache invalidation
The hardcoded query key string didn't match the actual generated key,
causing cache invalidation to fail after joining a waitlist. Now uses
the generated getGetV2GetWaitlistIdsTheCurrentUserHasJoinedQueryKey()
function for correct cache invalidation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:44:35 -06:00
Nicholas Tindle
205b220e90 fix(backend): filter out DONE/CANCELED waitlists before sending notifications
The notify_waitlist_users_on_launch function was not filtering by
waitlist status, which could cause duplicate notifications when an
agent is re-approved. Now excludes DONE and CANCELED waitlists,
consistent with get_waitlist() and add_user_to_waitlist().

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:41:37 -06:00
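A sketch of the status filter described above, using Prisma Client Python filter syntax; field names are assumptions based on the commit messages:

async def get_active_waitlists_for_listing(db, store_listing_id: str):
    return await db.waitlist.find_many(
        where={
            "storeListingId": store_listing_id,
            # Skip waitlists that already completed or were canceled so that
            # re-approving an agent does not notify the same users twice.
            "status": {"not_in": ["DONE", "CANCELED"]},
        }
    )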
Nicholas Tindle
29a232fcb4 fix(frontend): add URL validation and sandbox to video player
- Add getYouTubeVideoId() to extract video IDs from YouTube URLs
- Add isValidVideoUrl() to validate video URLs before rendering
- Create VideoPlayer component that:
  - Embeds YouTube videos via iframe with safe embed URL
  - Adds sandbox attribute to restrict iframe capabilities
  - Adds proper allow attributes for media playback
  - Falls back to native video element for valid non-YouTube URLs
  - Shows error state for invalid URLs

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:29:10 -06:00
Nicholas Tindle
a53f261812 feat(frontend): add TODO warning for email-only waitlist notifications
Adds a warning banner on the admin waitlist page indicating that
notifications for email-only signups (non-logged-in users) have not
been implemented yet.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:05:21 -06:00
Nicholas Tindle
00a20f77be feat(backend): add waitlist_launch email notification template
The WAITLIST_LAUNCH notification type was referencing a template that
didn't exist, causing FileNotFoundError when trying to notify users
that an agent they waitlisted has launched.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:38:04 -06:00
Nicholas Tindle
4d49536a40 Discard changes to autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts 2026-01-12 15:28:37 -07:00
Nicholas Tindle
6028a2528c refactor(frontend): consolidate waitlist modals and align with Figma design
- Merge JoinWaitlistModal into WaitlistDetailModal for unified experience
- Add MediaCarousel component supporting videos and images with play overlay
- Update WaitlistCard styling to match Figma (rounded-large, line-clamp-5, zinc-800 button)
- Update success state with party emoji and Close button per Figma design
- Add sticky footer for buttons during modal scroll
- Support email input for non-logged-in users

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:27:09 -06:00
Nicholas Tindle
b31cd05675 fix(backend): correct typo in unaffiliatedEmailUsers field name
- Rename unafilliatedEmailUsers -> unaffiliatedEmailUsers in schema.prisma
- Update migration SQL to use correct column name
- Update all references in db.py and model.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:33:38 -06:00
Nicholas Tindle
128366772f refactor(backend): remove apscheduler tables from prisma schema
- Remove apscheduler_jobs and apscheduler_jobs_batched_notifications models
- Delete migration 20260107000001_add_apscheduler_tables
- Remove index rename statements from waitlist migration

APScheduler tables are managed at runtime by APScheduler itself and
should not be part of the Prisma schema.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:29:17 -06:00
Nicholas Tindle
764cdf17fe refactor(frontend): migrate waitlist admin components to generated API hooks
- Convert WaitlistTable to use generated React Query hooks directly
- Convert CreateWaitlistButton to use generated hooks
- Update WaitlistDetailModal to use generated types and design system Dialog
- Remove deprecated waitlist types from types.ts
- Remove deprecated waitlist methods from BackendAPI client
- Delete actions.ts server actions (no longer needed)
- Replace lucide-react icons with Phosphor icons

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:26:34 -06:00
Nicholas Tindle
1dd83b4cf8 fix(frontend): add text color to status badge fallback in WaitlistTable
Ensures unknown status values have readable text contrast by adding
text-gray-700 to the fallback className.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:09:44 -06:00
Nicholas Tindle
24a34f7ce5 Merge branch 'dev' into ntindle/waitlist 2026-01-12 14:08:48 -07:00
Nicholas Tindle
20fe2c3877 fix(backend): remove PII-exposing fields from public waitlist model
Remove `owner` (User type) and `storeListing` (StoreListingWithVersions)
fields from StoreWaitlistEntry. These fields were never populated but
exposed PII types (email, stripe_customer_id, etc.) in the OpenAPI schema.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 14:52:51 -06:00
Nicholas Tindle
738c7e2bef fix(platform): address remaining PR review feedback for waitlist
Backend fixes:
- Fix optional field clearing by using model_fields_set
- Re-fetch waitlist data after join operation
- Only mark waitlist as DONE if all notifications succeed
- Fix race condition in email removal with transaction
- Rename waitlist_id to waitlistId for naming consistency

Frontend fixes:
- Migrate useWaitlistSection to generated API hooks
- Migrate JoinWaitlistModal to design system + generated hooks
- Migrate WaitlistSignupsDialog to design system + generated hooks
- Replace lucide-react icons with Phosphor in WaitlistTable
- Add proper error state in WaitlistSignupsDialog
- Update waitlistId naming across components

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 14:43:10 -06:00
Nicholas Tindle
9edfe0fb97 refactor(frontend): migrate EditWaitlistDialog to design system and generated API
- Replace legacy Dialog components with molecules/Dialog
- Replace legacy Input/Label/Textarea with atoms/Input
- Replace legacy Select with atoms/Select
- Replace @/lib/autogpt-server-api/types with @/app/api/__generated__/models
- Replace updateWaitlist action with usePutV2UpdateWaitlist hook
- Remove dependency on BackendAPI in favor of generated React Query hooks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:49:35 -07:00
Nicholas Tindle
4aabe71001 fix(platform): address PR review feedback for waitlist feature
Backend fixes:
- Fix creator_username null check in store URL construction
- Add embed=True to link_waitlist_to_listing endpoint body param
- Fix race condition in email list with transaction wrapper
- Replace str(e) with generic error messages in admin ValueError handlers
- Add validation requiring user_id or email in waitlist join
- Configure WAITLIST_LAUNCH in notification system (data type, queue, template, subject)
- Change StoreListing cascade delete to SetNull to preserve waitlist data

Frontend fixes:
- Escape internal quotes in CSV export for proper RFC 4180 compliance
- Remove incorrect 'use server' directive from page.tsx
- Replace lucide-react Check icon with Phosphor Icons

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:40:35 -07:00
Nicholas Tindle
b3999669f2 refactor(platform): simplify waitlist code and remove type duplication
- Backend: Extract _waitlist_to_store_entry helper to reduce duplication
- Backend: Use dict comprehension in update_waitlist_admin for cleaner code
- Frontend: Import types directly from shared types file instead of re-exporting
- Frontend: Remove redundant isMember check in WaitlistCard handleJoinClick

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:25:27 -07:00
Swifty
8c45a5ee98 Merge branch 'dev' into ntindle/waitlist 2026-01-08 12:38:46 +01:00
Nicholas Tindle
4b654c7e9f fix(frontend): Fix lint and type errors in waitlist admin components
- Remove unused WaitlistSignup import
- Change button size from "sm" to "small"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 22:48:53 -07:00
Nicholas Tindle
8d82e3b633 fix(backend): Use Prisma connect pattern for waitlist-listing relation
Use StoreListing relation with connect pattern instead of directly
setting storeListingId, which doesn't work with Prisma's typed update.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 22:01:18 -07:00
Nicholas Tindle
d4ecdb64ed feat(platform): Show "On the waitlist" status for joined users
- Add GET /api/store/waitlist/my-memberships endpoint to fetch user's joined waitlists
- Add get_user_waitlist_memberships() db function
- Update useWaitlistSection hook to fetch memberships when logged in
- Update WaitlistCard to show green "On the waitlist" button for members
- Update WaitlistDetailModal to show member status
- Add onSuccess callback to JoinWaitlistModal for optimistic UI updates

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 21:15:03 -07:00
Nicholas Tindle
a73fb8f114 feat(platform): Add waitlist feature with admin management and user notifications
Backend:
- Add waitlist admin API routes for CRUD operations
- Add admin functions for waitlist management (create, update, delete, list)
- Add WaitlistLaunchData notification type for user notifications
- Integrate waitlist notifications into store submission approval flow
- Auto-notify waitlist users when linked agent is approved

Frontend:
- Add admin waitlist management page with table, create/edit dialogs
- Add WaitlistSection component to marketplace homepage
- Add WaitlistCard, WaitlistDetailModal, JoinWaitlistModal components
- Add API client methods and types for waitlist operations

Database:
- Add WAITLIST_LAUNCH notification type enum
- Add baseline migration for APScheduler tables

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 20:38:15 -07:00
Nicholas Tindle
2c60aa64ef wip: adding waitlist 2026-01-06 22:13:35 -07:00
458 changed files with 9514 additions and 16327 deletions

View File

@@ -128,7 +128,7 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
-e2e_test:
+test:
runs-on: big-boi
needs: setup
strategy:
@@ -258,39 +258,3 @@ jobs:
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.yml logs
integration_test:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Generate API client
run: pnpm generate:api
- name: Run Integration Tests
run: pnpm test:unit

View File

@@ -16,32 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
## Testing
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).

View File

@@ -201,7 +201,7 @@ If you get any pushback or hit complex block conditions check the new_blocks gui
3. Write tests alongside the route file
4. Run `poetry run test` to verify
-### Frontend guidelines:
+**Frontend feature development:**
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -217,14 +217,6 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
### Security Implementation

View File

@@ -178,10 +178,5 @@ AYRSHARE_JWT_KEY=
SMARTLEAD_API_KEY=
ZEROBOUNCE_API_KEY=
# PostHog Analytics
# Get API key from https://posthog.com - Project Settings > Project API Key
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Other Services
AUTOMOD_API_KEY=

View File

@@ -86,8 +86,6 @@ async def execute_graph_block(
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:
raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")
output = defaultdict(list)
async for name, data in obj.execute(data):

View File

@@ -0,0 +1,251 @@
import logging
import autogpt_libs.auth
import fastapi
import fastapi.responses
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
prefix="/admin/waitlist",
tags=["store", "admin", "waitlist"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
@router.post(
"",
summary="Create Waitlist",
response_model=store_model.WaitlistAdminResponse,
)
async def create_waitlist(
request: store_model.WaitlistCreateRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new waitlist (admin only).
Args:
request: Waitlist creation details
user_id: Authenticated admin user creating the waitlist
Returns:
WaitlistAdminResponse with the created waitlist details
"""
try:
waitlist = await store_db.create_waitlist_admin(
admin_user_id=user_id,
data=request,
)
return waitlist
except Exception as e:
logger.exception("Error creating waitlist: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while creating the waitlist"},
)
@router.get(
"",
summary="List All Waitlists",
response_model=store_model.WaitlistAdminListResponse,
)
async def list_waitlists():
"""
Get all waitlists with admin details (admin only).
Returns:
WaitlistAdminListResponse with all waitlists
"""
try:
return await store_db.get_waitlists_admin()
except Exception as e:
logger.exception("Error listing waitlists: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while fetching waitlists"},
)
@router.get(
"/{waitlist_id}",
summary="Get Waitlist Details",
response_model=store_model.WaitlistAdminResponse,
)
async def get_waitlist(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Get a single waitlist with admin details (admin only).
Args:
waitlist_id: ID of the waitlist to retrieve
Returns:
WaitlistAdminResponse with waitlist details
"""
try:
return await store_db.get_waitlist_admin(waitlist_id)
except ValueError:
logger.warning("Waitlist not found: %s", waitlist_id)
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Waitlist not found"},
)
except Exception as e:
logger.exception("Error fetching waitlist: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while fetching the waitlist"},
)
@router.put(
"/{waitlist_id}",
summary="Update Waitlist",
response_model=store_model.WaitlistAdminResponse,
)
async def update_waitlist(
request: store_model.WaitlistUpdateRequest,
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Update a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist to update
request: Fields to update
Returns:
WaitlistAdminResponse with updated waitlist details
"""
try:
return await store_db.update_waitlist_admin(waitlist_id, request)
except ValueError:
logger.warning("Waitlist not found for update: %s", waitlist_id)
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Waitlist not found"},
)
except Exception as e:
logger.exception("Error updating waitlist: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while updating the waitlist"},
)
@router.delete(
"/{waitlist_id}",
summary="Delete Waitlist",
)
async def delete_waitlist(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Soft delete a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist to delete
Returns:
Success message
"""
try:
await store_db.delete_waitlist_admin(waitlist_id)
return {"message": "Waitlist deleted successfully"}
except ValueError:
logger.warning(f"Waitlist not found for deletion: {waitlist_id}")
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Waitlist not found"},
)
except Exception as e:
logger.exception("Error deleting waitlist: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while deleting the waitlist"},
)
@router.get(
"/{waitlist_id}/signups",
summary="Get Waitlist Signups",
response_model=store_model.WaitlistSignupListResponse,
)
async def get_waitlist_signups(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Get all signups for a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist
Returns:
WaitlistSignupListResponse with all signups
"""
try:
return await store_db.get_waitlist_signups_admin(waitlist_id)
except ValueError:
logger.warning("Waitlist not found for signups: %s", waitlist_id)
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Waitlist not found"},
)
except Exception as e:
logger.exception("Error fetching waitlist signups: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while fetching waitlist signups"},
)
@router.post(
"/{waitlist_id}/link",
summary="Link Waitlist to Store Listing",
response_model=store_model.WaitlistAdminResponse,
)
async def link_waitlist_to_listing(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
store_listing_id: str = fastapi.Body(
..., embed=True, description="The ID of the store listing"
),
):
"""
Link a waitlist to a store listing (admin only).
When the linked store listing is approved/published, waitlist users
will be automatically notified.
Args:
waitlist_id: ID of the waitlist
store_listing_id: ID of the store listing to link
Returns:
WaitlistAdminResponse with updated waitlist details
"""
try:
return await store_db.link_waitlist_to_listing_admin(
waitlist_id, store_listing_id
)
except ValueError:
logger.warning(
"Link failed - waitlist or listing not found: %s, %s",
waitlist_id,
store_listing_id,
)
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Waitlist or store listing not found"},
)
except Exception as e:
logger.exception("Error linking waitlist to listing: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while linking the waitlist"},
)

View File

@@ -290,11 +290,6 @@ async def _cache_session(session: ChatSession) -> None:
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)

View File

@@ -31,7 +31,6 @@ class ResponseType(str, Enum):
# Other
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
class StreamBaseResponse(BaseModel):
@@ -143,20 +142,3 @@ class StreamError(StreamBaseResponse):
details: dict[str, Any] | None = Field(
default=None, description="Additional error details"
)
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.
Uses SSE comment format (: comment) which is ignored by clients but keeps
the connection alive through proxies and load balancers.
"""
type: ResponseType = ResponseType.HEARTBEAT
toolCallId: str | None = Field(
default=None, description="Tool call ID if heartbeat is for a specific tool"
)
def to_sse(self) -> str:
"""Convert to SSE comment format to keep connection alive."""
return ": heartbeat\n\n"

View File

@@ -172,12 +172,12 @@ async def get_session(
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
-SessionDetailResponse: Details for the requested session, or None if not found.
+SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
"""
session = await get_chat_session(session_id, user_id)
if not session:
-raise NotFoundError(f"Session {session_id} not found.")
+raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
@@ -222,8 +222,6 @@ async def stream_chat_post(
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
-chunk_count = 0
-first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
@@ -232,26 +230,7 @@
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
@@ -296,8 +275,6 @@ async def stream_chat_get(
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
-chunk_count = 0
-first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
@@ -305,26 +282,7 @@
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"

File diff suppressed because it is too large

View File

@@ -1,10 +1,8 @@
-import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -22,8 +20,6 @@ from .search_docs import SearchDocsTool
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
-logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
@@ -60,17 +56,4 @@ async def execute_tool(
tool = TOOL_REGISTRY.get(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
# Track tool call in PostHog
logger.info(
f"Tracking tool call: tool={tool_name}, user={user_id}, "
f"session={session.session_id}, call_id={tool_call_id}"
)
track_tool_called(
user_id=user_id,
session_id=session.session_id,
tool_name=tool_name,
tool_call_id=tool_call_id,
)
return await tool.execute(user_id, session, tool_call_id, **parameters)

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
@@ -59,6 +61,7 @@ and automations for the user's specific needs."""
"""Requires authentication to store user-specific data.""" """Requires authentication to store user-specific data."""
return True return True
@observe(as_type="tool", name="add_understanding")
async def _execute(
self,
user_id: str | None,

View File

@@ -1,28 +1,29 @@
"""Agent generator package - Creates agents from natural language.""" """Agent generator package - Creates agents from natural language."""
from .core import ( from .core import (
AgentGeneratorNotConfiguredError, apply_agent_patch,
decompose_goal, decompose_goal,
generate_agent, generate_agent,
generate_agent_patch, generate_agent_patch,
get_agent_as_json, get_agent_as_json,
json_to_graph,
save_agent_to_library, save_agent_to_library,
) )
from .service import health_check as check_external_service_health from .fixer import apply_all_fixes
from .service import is_external_service_configured from .utils import get_blocks_info
from .validator import validate_agent
__all__ = [ __all__ = [
# Core functions # Core functions
"decompose_goal", "decompose_goal",
"generate_agent", "generate_agent",
"generate_agent_patch", "generate_agent_patch",
"apply_agent_patch",
"save_agent_to_library", "save_agent_to_library",
"get_agent_as_json", "get_agent_as_json",
"json_to_graph", # Fixer
# Exceptions "apply_all_fixes",
"AgentGeneratorNotConfiguredError", # Validator
# Service "validate_agent",
"is_external_service_configured", # Utils
"check_external_service_health", "get_blocks_info",
] ]

View File

@@ -0,0 +1,25 @@
"""OpenRouter client configuration for agent generation."""
import os
from openai import AsyncOpenAI
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
# OpenRouter client (OpenAI-compatible API)
_client: AsyncOpenAI | None = None
def get_client() -> AsyncOpenAI:
"""Get or create the OpenRouter client."""
global _client
if _client is None:
if not OPENROUTER_API_KEY:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
_client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=OPENROUTER_API_KEY,
)
return _client
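# Usage sketch (not part of the diff above): how a caller might drive the client
# defined in this module; the prompt content is a placeholder.
async def _example_call() -> str | None:
    client = get_client()
    response = await client.chat.completions.create(
        model=AGENT_GENERATOR_MODEL,
        messages=[{"role": "user", "content": "Summarize this goal in one line."}],
        temperature=0,
    )
    return response.choices[0].message.content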

View File

@@ -1,5 +1,7 @@
"""Core agent generation functions.""" """Core agent generation functions."""
import copy
import json
import logging import logging
import uuid import uuid
from typing import Any from typing import Any
@@ -7,35 +9,13 @@ from typing import Any
from backend.api.features.library import db as library_db from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph from backend.data.graph import Graph, Link, Node, create_graph
from .service import ( from .client import AGENT_GENERATOR_MODEL, get_client
decompose_goal_external, from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
generate_agent_external, from .utils import get_block_summaries, parse_json_from_llm
generate_agent_patch_external,
is_external_service_configured,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
@@ -48,13 +28,40 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
-Raises:
-AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
-_check_service_configured()
-logger.info("Calling external Agent Generator service for decompose_goal")
-return await decompose_goal_external(description, context)
+client = get_client()
+prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
+full_description = description
+if context:
+full_description = f"{description}\n\nAdditional context:\n{context}"
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": full_description},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for decomposition")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse decomposition response: {content[:200]}")
return None
return result
except Exception as e:
logger.error(f"Error decomposing goal: {e}")
return None
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
@@ -65,14 +72,31 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
Returns:
Agent JSON dict or None on error
-Raises:
-AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
-_check_service_configured()
-logger.info("Calling external Agent Generator service for generate_agent")
-result = await generate_agent_external(instructions)
-if result:
+client = get_client()
+prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
+try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": json.dumps(instructions, indent=2)},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for agent generation")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse agent JSON: {content[:200]}")
return None
# Ensure required fields
if "id" not in result:
result["id"] = str(uuid.uuid4())
@@ -80,7 +104,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
-return result
+return result
+except Exception as e:
+logger.error(f"Error generating agent: {e}")
+return None
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
@@ -189,7 +218,6 @@ async def save_agent_to_library(
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
-sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
@@ -255,23 +283,108 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
-"""Update an existing agent using natural language.
-The external Agent Generator service handles:
-- Generating the patch
-- Applying the patch
-- Fixing and validating the result
+"""Generate a patch to update an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
-Updated agent JSON, clarifying questions dict, or None on error
-Raises:
-AgentGeneratorNotConfiguredError: If the external service is not configured.
+Patch dict or clarifying questions, or None on error
"""
-_check_service_configured()
-logger.info("Calling external Agent Generator service for generate_agent_patch")
-return await generate_agent_patch_external(update_request, current_agent)
+client = get_client()
+prompt = PATCH_PROMPT.format(
+current_agent=json.dumps(current_agent, indent=2),
+block_summaries=get_block_summaries(),
+)
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": update_request},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for patch generation")
return None
return parse_json_from_llm(content)
except Exception as e:
logger.error(f"Error generating patch: {e}")
return None
def apply_agent_patch(
current_agent: dict[str, Any], patch: dict[str, Any]
) -> dict[str, Any]:
"""Apply a patch to an existing agent.
Args:
current_agent: Current agent JSON
patch: Patch dict with operations
Returns:
Updated agent JSON
"""
agent = copy.deepcopy(current_agent)
patches = patch.get("patches", [])
for p in patches:
patch_type = p.get("type")
if patch_type == "modify":
node_id = p.get("node_id")
changes = p.get("changes", {})
for node in agent.get("nodes", []):
if node["id"] == node_id:
_deep_update(node, changes)
logger.debug(f"Modified node {node_id}")
break
elif patch_type == "add":
new_nodes = p.get("new_nodes", [])
new_links = p.get("new_links", [])
agent["nodes"] = agent.get("nodes", []) + new_nodes
agent["links"] = agent.get("links", []) + new_links
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
elif patch_type == "remove":
node_ids_to_remove = set(p.get("node_ids", []))
link_ids_to_remove = set(p.get("link_ids", []))
# Remove nodes
agent["nodes"] = [
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
]
# Remove links (both explicit and those referencing removed nodes)
agent["links"] = [
link
for link in agent.get("links", [])
if link["id"] not in link_ids_to_remove
and link["source_id"] not in node_ids_to_remove
and link["sink_id"] not in node_ids_to_remove
]
logger.debug(
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
)
return agent
def _deep_update(target: dict, source: dict) -> None:
"""Recursively update a dict with another dict."""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
_deep_update(target[key], value)
else:
target[key] = value
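# Usage sketch (not part of the diff above): applying a hypothetical "remove"
# patch with apply_agent_patch as defined earlier in this file.
def _example_apply_patch() -> dict[str, Any]:
    current = {
        "nodes": [{"id": "a"}, {"id": "b"}],
        "links": [{"id": "l1", "source_id": "a", "sink_id": "b"}],
    }
    patch = {"patches": [{"type": "remove", "node_ids": ["b"], "link_ids": []}]}
    # Node "b" is dropped along with any link that references it; "a" remains.
    return apply_agent_patch(current, patch)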

View File

@@ -0,0 +1,606 @@
"""Agent fixer - Fixes common LLM generation errors."""
import logging
import re
import uuid
from typing import Any
from .utils import (
ADDTODICTIONARY_BLOCK_ID,
ADDTOLIST_BLOCK_ID,
CODE_EXECUTION_BLOCK_ID,
CONDITION_BLOCK_ID,
CREATEDICT_BLOCK_ID,
CREATELIST_BLOCK_ID,
DATA_SAMPLING_BLOCK_ID,
DOUBLE_CURLY_BRACES_BLOCK_IDS,
GET_CURRENT_DATE_BLOCK_ID,
STORE_VALUE_BLOCK_ID,
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
get_blocks_info,
is_valid_uuid,
)
logger = logging.getLogger(__name__)
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix invalid UUIDs in agent and link IDs."""
# Fix agent ID
if not is_valid_uuid(agent.get("id", "")):
agent["id"] = str(uuid.uuid4())
logger.debug(f"Fixed agent ID: {agent['id']}")
# Fix node IDs
id_mapping = {} # Old ID -> New ID
for node in agent.get("nodes", []):
if not is_valid_uuid(node.get("id", "")):
old_id = node.get("id", "")
new_id = str(uuid.uuid4())
id_mapping[old_id] = new_id
node["id"] = new_id
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
# Fix link IDs and update references
for link in agent.get("links", []):
if not is_valid_uuid(link.get("id", "")):
link["id"] = str(uuid.uuid4())
logger.debug(f"Fixed link ID: {link['id']}")
# Update source/sink IDs if they were remapped
if link.get("source_id") in id_mapping:
link["source_id"] = id_mapping[link["source_id"]]
if link.get("sink_id") in id_mapping:
link["sink_id"] = id_mapping[link["sink_id"]]
return agent
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix single curly braces to double in template blocks."""
for node in agent.get("nodes", []):
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
continue
input_data = node.get("input_default", {})
for key in ("prompt", "format"):
if key in input_data and isinstance(input_data[key], str):
original = input_data[key]
# Fix simple variable references: {var} -> {{var}}
fixed = re.sub(
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
r"{{\1}}",
original,
)
if fixed != original:
input_data[key] = fixed
logger.debug(f"Fixed curly braces in {key}")
return agent
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Find all ConditionBlock nodes
condition_node_ids = {
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
}
if not condition_node_ids:
return agent
new_nodes = []
new_links = []
processed_conditions = set()
for link in links:
sink_id = link.get("sink_id")
sink_name = link.get("sink_name")
# Check if this link goes to a ConditionBlock's value2
if sink_id in condition_node_ids and sink_name == "value2":
source_node = next(
(n for n in nodes if n["id"] == link.get("source_id")), None
)
# Skip if source is already a StoreValueBlock
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
continue
# Skip if we already processed this condition
if sink_id in processed_conditions:
continue
processed_conditions.add(sink_id)
# Create StoreValueBlock
store_node_id = str(uuid.uuid4())
store_node = {
"id": store_node_id,
"block_id": STORE_VALUE_BLOCK_ID,
"input_default": {"data": None},
"metadata": {"position": {"x": 0, "y": -100}},
}
new_nodes.append(store_node)
# Create link: original source -> StoreValueBlock
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": store_node_id,
"sink_name": "input",
"is_static": False,
}
)
# Update original link: StoreValueBlock -> ConditionBlock
link["source_id"] = store_node_id
link["source_name"] = "output"
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
if new_nodes:
agent["nodes"] = nodes + new_nodes
return agent
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
When an AddToList block is found:
1. Checks if there's a CreateListBlock before it
2. Removes CreateListBlock if linked directly to AddToList
3. Adds an empty AddToList block before the original
4. Ensures the original has a self-referencing link
"""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
new_nodes = []
original_addtolist_ids = set()
nodes_to_remove = set()
links_to_remove = []
# First pass: identify CreateListBlock nodes to remove
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATELIST_BLOCK_ID
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
# Second pass: process AddToList blocks
filtered_nodes = []
for node in nodes:
if node.get("id") in nodes_to_remove:
continue
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
original_addtolist_ids.add(node.get("id"))
node_id = node.get("id")
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
# Check if already has prerequisite
has_prereq = any(
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_name") == "updated_list"
for link in links
)
if not has_prereq:
# Remove links to "list" input (except self-reference)
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_id") != node_id
and link not in links_to_remove
):
links_to_remove.append(link)
# Create prerequisite AddToList block
prereq_id = str(uuid.uuid4())
prereq_node = {
"id": prereq_id,
"block_id": ADDTOLIST_BLOCK_ID,
"input_default": {"list": [], "entry": None, "entries": []},
"metadata": {
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
},
}
new_nodes.append(prereq_node)
# Link prerequisite to original
links.append(
{
"id": str(uuid.uuid4()),
"source_id": prereq_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added prerequisite AddToList block for {node_id}")
filtered_nodes.append(node)
# Remove marked links
filtered_links = [link for link in links if link not in links_to_remove]
# Add self-referencing links for original AddToList blocks
for node in filtered_nodes + new_nodes:
if (
node.get("block_id") == ADDTOLIST_BLOCK_ID
and node.get("id") in original_addtolist_ids
):
node_id = node.get("id")
has_self_ref = any(
link["source_id"] == node_id
and link["sink_id"] == node_id
and link["source_name"] == "updated_list"
and link["sink_name"] == "list"
for link in filtered_links
)
if not has_self_ref:
filtered_links.append(
{
"id": str(uuid.uuid4()),
"source_id": node_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added self-reference for AddToList {node_id}")
agent["nodes"] = filtered_nodes + new_nodes
agent["links"] = filtered_links
return agent
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
nodes_to_remove = set()
links_to_remove = []
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
if (
source_node
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
and link.get("source_name") == "response"
):
link["source_name"] = "stdout_logs"
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
return agent
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
links_to_remove = []
for node in nodes:
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Remove links to sample_size
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "sample_size"
):
links_to_remove.append(link)
# Set default
input_default["sample_size"] = 1
node["input_default"] = input_default
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
if links_to_remove:
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
node_lookup = {n.get("id"): n for n in nodes}
for link in links:
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_node = node_lookup.get(source_id)
sink_node = node_lookup.get(sink_id)
if not source_node or not sink_node:
continue
source_pos = source_node.get("metadata", {}).get("position", {})
sink_pos = sink_node.get("metadata", {}).get("position", {})
source_x = source_pos.get("x", 0)
sink_x = sink_pos.get("x", 0)
if abs(sink_x - source_x) < 800:
new_x = source_x + 800
if "metadata" not in sink_node:
sink_node["metadata"] = {}
if "position" not in sink_node["metadata"]:
sink_node["metadata"]["position"] = {}
sink_node["metadata"]["position"]["x"] = new_x
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
return agent
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
for node in agent.get("nodes", []):
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
input_default = node.get("input_default", {})
if "offset" in input_default:
offset = input_default["offset"]
if isinstance(offset, (int, float)) and offset < 0:
input_default["offset"] = abs(offset)
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
return agent
def fix_ai_model_parameter(
agent: dict[str, Any],
blocks_info: list[dict[str, Any]],
default_model: str = "gpt-4o",
) -> dict[str, Any]:
"""Add default model parameter to AI blocks if missing."""
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
# Check if block has AI category
categories = block.get("categories", [])
is_ai_block = any(
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
)
if is_ai_block:
input_default = node.get("input_default", {})
if "model" not in input_default:
input_default["model"] = default_model
node["input_default"] = input_default
logger.debug(
f"Added model '{default_model}' to AI block {node.get('id')}"
)
return agent
def fix_link_static_properties(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix is_static property based on source block's staticOutput."""
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
if not source_node:
continue
source_block = block_map.get(source_node.get("block_id"))
if not source_block:
continue
static_output = source_block.get("staticOutput", False)
if link.get("is_static") != static_output:
link["is_static"] = static_output
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
return agent
def fix_data_type_mismatch(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in nodes}
def get_property_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
type_mapping = {
"string": "string",
"text": "string",
"integer": "number",
"number": "number",
"float": "number",
"boolean": "boolean",
"bool": "boolean",
"array": "list",
"list": "list",
"object": "dictionary",
"dict": "dictionary",
"dictionary": "dictionary",
}
new_links = []
nodes_to_add = []
for link in links:
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
new_links.append(link)
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
new_links.append(link)
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_property_type(source_outputs, link.get("source_name", ""))
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
# Insert type converter
converter_id = str(uuid.uuid4())
target_type = type_mapping.get(sink_type, sink_type)
converter_node = {
"id": converter_id,
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
"input_default": {"type": target_type},
"metadata": {"position": {"x": 0, "y": 100}},
}
nodes_to_add.append(converter_node)
# source -> converter
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": converter_id,
"sink_name": "value",
"is_static": False,
}
)
# converter -> sink
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": converter_id,
"source_name": "value",
"sink_id": link["sink_id"],
"sink_name": link["sink_name"],
"is_static": False,
}
)
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
else:
new_links.append(link)
if nodes_to_add:
agent["nodes"] = nodes + nodes_to_add
agent["links"] = new_links
return agent
def apply_all_fixes(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""Apply all fixes to an agent JSON.
Args:
agent: Agent JSON dict
blocks_info: Optional list of block info dicts for advanced fixes
Returns:
Fixed agent JSON
"""
# Basic fixes (no block info needed)
agent = fix_agent_ids(agent)
agent = fix_double_curly_braces(agent)
agent = fix_storevalue_before_condition(agent)
agent = fix_addtolist_blocks(agent)
agent = fix_addtodictionary_blocks(agent)
agent = fix_code_execution_output(agent)
agent = fix_data_sampling_sample_size(agent)
agent = fix_node_x_coordinates(agent)
agent = fix_getcurrentdate_offset(agent)
# Advanced fixes (require block info)
if blocks_info is None:
blocks_info = get_blocks_info()
agent = fix_ai_model_parameter(agent, blocks_info)
agent = fix_link_static_properties(agent, blocks_info)
agent = fix_data_type_mismatch(agent, blocks_info)
return agent
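
For illustration, a minimal sketch of invoking this fix pipeline on a hypothetical agent dict (the node and IDs below are made up; passing an explicit `blocks_info` list avoids loading the real block registry, and the earlier fixers not shown in this excerpt are assumed to tolerate a minimal dict):

```python
from typing import Any

# Hypothetical LLM-generated agent; only the fields the fixers inspect are shown.
agent: dict[str, Any] = {
    "id": "22222222-2222-4222-8222-222222222222",
    "name": "Example Agent",
    "nodes": [
        {
            "id": "11111111-1111-4111-8111-111111111111",
            "block_id": GET_CURRENT_DATE_BLOCK_ID,
            "input_default": {"offset": -3},  # fix_getcurrentdate_offset flips this to 3
            "metadata": {"position": {"x": 0, "y": 0}},
        }
    ],
    "links": [],
}

# With blocks_info=[] the block-aware fixes become no-ops; omit it to use get_blocks_info().
fixed = apply_all_fixes(agent, blocks_info=[])
```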

View File

@@ -0,0 +1,225 @@
"""Prompt templates for agent generation."""
DECOMPOSITION_PROMPT = """
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
---
FIRST: Analyze the user's goal and determine:
1) Design-time configuration (fixed settings that won't change per run)
2) Runtime inputs (values the agent's end-user will provide each time it runs)
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
- DO NOT ask for the actual value
- Instead, define it as an Agent Input with a clear name, type, and description
Only ask clarifying questions about design-time config that affects how you build the workflow:
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
- Business rules that must be hard-coded
IMPORTANT CLARIFICATIONS POLICY:
- Ask no more than five essential questions
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
- Do not ask for API keys or credentials; the platform handles those directly
- If there is enough information to infer reasonable defaults, prefer to propose defaults
---
GUIDELINES:
1. List each step as a numbered item
2. Describe the action clearly and specify inputs/outputs
3. Ensure steps are in logical, sequential order
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
5. Help the user reach their goal efficiently
---
RULES:
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
2. USE ONLY THE BLOCKS PROVIDED
3. ALL required_input fields must be provided
4. Data types of linked properties must match
5. Write expert-level prompts for AI-related blocks
---
CRITICAL BLOCK RESTRICTIONS:
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
3. ConditionBlock: value2 is reference, value1 is contrast
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
---
OUTPUT FORMAT:
If more information is needed:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
"keyword": "email_provider",
"example": "Gmail"
}}
]
}}
```
If ready to proceed:
```json
{{
"type": "instructions",
"steps": [
{{
"step_number": 1,
"block_name": "AgentShortTextInputBlock",
"description": "Get the URL of the content to analyze.",
"inputs": [{{"name": "name", "value": "URL"}}],
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
}}
]
}}
```
---
AVAILABLE BLOCKS:
{block_summaries}
"""
GENERATION_PROMPT = """
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
---
NODES:
Each node must include:
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
- `block_id`: The block identifier (must match an Allowed Block)
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
- `metadata`: Must contain:
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
- `customized_name`: Clear name describing this block's purpose in the workflow
---
LINKS:
Each link connects a source node's output to a sink node's input:
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
- `source_id`: ID of the source node
- `source_name`: Output field name from the source block
- `sink_id`: ID of the sink node
- `sink_name`: Input field name on the sink block
- `is_static`: true only if source block has static_output: true
CRITICAL: All IDs must be valid UUID v4 format!
---
AGENT (GRAPH):
Wrap nodes and links in:
- `id`: UUID of the agent
- `name`: Short, generic name (avoid specific company names, URLs)
- `description`: Short, generic description
- `nodes`: List of all nodes
- `links`: List of all links
- `version`: 1
- `is_active`: true
---
TIPS:
- All required_input fields must be provided via input_default or a valid link
- Ensure consistent source_id and sink_id references
- Avoid dangling links
- Input/output pins must match block schemas
- Do not invent unknown block_ids
---
ALLOWED BLOCKS:
{block_summaries}
---
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
"""
PATCH_PROMPT = """
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
CURRENT AGENT:
{current_agent}
AVAILABLE BLOCKS:
{block_summaries}
---
PATCH FORMAT:
Return a JSON object with the following structure:
```json
{{
"type": "patch",
"intent": "Brief description of what the patch does",
"patches": [
{{
"type": "modify",
"node_id": "uuid-of-node-to-modify",
"changes": {{
"input_default": {{"field": "new_value"}},
"metadata": {{"customized_name": "New Name"}}
}}
}},
{{
"type": "add",
"new_nodes": [
{{
"id": "new-uuid",
"block_id": "block-uuid",
"input_default": {{}},
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
}}
],
"new_links": [
{{
"id": "link-uuid",
"source_id": "source-node-id",
"source_name": "output_field",
"sink_id": "sink-node-id",
"sink_name": "input_field"
}}
]
}},
{{
"type": "remove",
"node_ids": ["uuid-of-node-to-remove"],
"link_ids": ["uuid-of-link-to-remove"]
}}
]
}}
```
If you need more information, return:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "What specific change do you want?",
"keyword": "change_type",
"example": "Add error handling"
}}
]
}}
```
Generate the minimal patch needed. Output ONLY valid JSON.
"""

View File

@@ -1,269 +0,0 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
Or None on error
"""
client = _get_client()
# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return None
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def generate_agent_external(
instructions: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
Returns:
Agent JSON dict or None on error
"""
client = _get_client()
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Updated agent JSON, clarifying questions dict, or None on error
"""
client = _get_client()
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -0,0 +1,213 @@
"""Utilities for agent generation."""
import json
import re
from typing import Any
from backend.data.block import get_blocks
# UUID validation regex
UUID_REGEX = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
)
# Block IDs for various fixes
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
"3b191d9f-356f-482d-8238-ba04b6d18381",
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
"716a67b3-6760-42e7-86dc-18645c6e00fc",
"530cf046-2ce0-4854-ae2c-659db17c7a46",
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
]
def is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID v4."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def _compact_schema(schema: dict) -> dict[str, str]:
"""Extract compact type info from a JSON schema properties dict.
Returns a dict of {field_name: type_string} for essential info only.
"""
props = schema.get("properties", {})
result = {}
for name, prop in props.items():
# Skip internal/complex fields
if name.startswith("_"):
continue
# Get type string
type_str = prop.get("type", "any")
# Handle anyOf/oneOf (optional types)
if "anyOf" in prop:
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
type_str = "|".join(types) if types else "any"
elif "allOf" in prop:
type_str = "object"
# Add array item type if present
if type_str == "array" and "items" in prop:
items = prop["items"]
if isinstance(items, dict):
item_type = items.get("type", "any")
type_str = f"array[{item_type}]"
result[name] = type_str
return result
def get_block_summaries(include_schemas: bool = True) -> str:
"""Generate compact block summaries for prompts.
Args:
include_schemas: Whether to include input/output type info
Returns:
Formatted string of block summaries (compact format)
"""
blocks = get_blocks()
summaries = []
for block_id, block_cls in blocks.items():
block = block_cls()
name = block.name
desc = getattr(block, "description", "") or ""
# Truncate description
if len(desc) > 150:
desc = desc[:147] + "..."
if not include_schemas:
summaries.append(f"- {name} (id: {block_id}): {desc}")
else:
# Compact format with type info only
inputs = {}
outputs = {}
required = []
if hasattr(block, "input_schema"):
try:
schema = block.input_schema.jsonschema()
inputs = _compact_schema(schema)
required = schema.get("required", [])
except Exception:
pass
if hasattr(block, "output_schema"):
try:
schema = block.output_schema.jsonschema()
outputs = _compact_schema(schema)
except Exception:
pass
# Build compact line format
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
req_str = f" req=[{','.join(required)}]" if required else ""
static = " [static]" if getattr(block, "static_output", False) else ""
line = f"- {name} (id: {block_id}): {desc}"
if in_str:
line += f"\n in: {{{in_str}}}{req_str}"
if out_str:
line += f"\n out: {{{out_str}}}{static}"
summaries.append(line)
return "\n".join(summaries)
def get_blocks_info() -> list[dict[str, Any]]:
"""Get block information with schemas for validation and fixing."""
blocks = get_blocks()
blocks_info = []
for block_id, block_cls in blocks.items():
block = block_cls()
blocks_info.append(
{
"id": block_id,
"name": block.name,
"description": getattr(block, "description", ""),
"categories": getattr(block, "categories", []),
"staticOutput": getattr(block, "static_output", False),
"inputSchema": (
block.input_schema.jsonschema()
if hasattr(block, "input_schema")
else {}
),
"outputSchema": (
block.output_schema.jsonschema()
if hasattr(block, "output_schema")
else {}
),
}
)
return blocks_info
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
"""Extract JSON from LLM response (handles markdown code blocks)."""
if not text:
return None
# Try fenced code block
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# Try raw text
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# Try finding {...} span
start = text.find("{")
end = text.rfind("}")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
# Try finding [...] span
start = text.find("[")
end = text.rfind("]")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
return None
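
As a quick usage sketch, the parser tolerates markdown-fenced replies and falls back to scanning for the first JSON object or array span:

```python
# Fenced reply: the ```json block is extracted and parsed.
raw = 'Sure, here is the agent:\n```json\n{"name": "Example Agent", "nodes": []}\n```'
assert parse_json_from_llm(raw) == {"name": "Example Agent", "nodes": []}

# No fence: the {...} span is located and parsed instead.
raw = 'Result -> {"steps": [1, 2, 3]} (done)'
assert parse_json_from_llm(raw) == {"steps": [1, 2, 3]}
```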

View File

@@ -0,0 +1,279 @@
"""Agent validator - Validates agent structure and connections."""
import logging
import re
from typing import Any
from .utils import get_blocks_info
logger = logging.getLogger(__name__)
class AgentValidator:
"""Validator for AutoGPT agents with detailed error reporting."""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error: str) -> None:
"""Add an error message."""
self.errors.append(error)
def validate_block_existence(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate all block IDs exist in the blocks library."""
valid = True
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
)
valid = False
return valid
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
"""Validate all node IDs referenced in links exist."""
valid = True
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
if not source_id:
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent source_id '{source_id}'."
)
valid = False
if not sink_id:
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
)
valid = False
return valid
def validate_required_inputs(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate required inputs are provided."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
# Get linked inputs
linked_inputs = {
link["sink_name"]
for link in agent.get("links", [])
if link.get("sink_id") == node_id
}
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate linked data types are compatible."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
def get_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_type(source_outputs, link.get("source_name", ""))
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
if source_type and sink_type and not are_compatible(source_type, sink_type):
self.add_error(
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate nested sink links (with _#_ notation)."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(link.get("sink_id"))
if not sink_node:
continue
block = block_map.get(sink_node.get("block_id"))
if not block:
continue
input_props = block.get("inputSchema", {}).get("properties", {})
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
)
valid = False
continue
if not parent_schema.get("additionalProperties"):
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and child in parent_schema.get("properties", {})
):
self.add_error(
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
)
valid = False
return valid
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
"""Validate prompts don't have spaces in template variables."""
valid = True
for node in agent.get("nodes", []):
input_default = node.get("input_default", {})
prompt = input_default.get("prompt", "")
if not isinstance(prompt, str):
continue
# Find {{...}} with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
for match in matches:
content = match.group(1)
if " " in content:
self.add_error(
f"Node '{node.get('id')}' has spaces in template variable: "
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
)
valid = False
return valid
def validate(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Run all validations.
Returns:
Tuple of (is_valid, error_message)
"""
self.errors = []
if blocks_info is None:
blocks_info = get_blocks_info()
checks = [
self.validate_block_existence(agent, blocks_info),
self.validate_link_node_references(agent),
self.validate_required_inputs(agent, blocks_info),
self.validate_data_type_compatibility(agent, blocks_info),
self.validate_nested_sink_links(agent, blocks_info),
self.validate_prompt_spaces(agent),
]
all_passed = all(checks)
if all_passed:
logger.info("Agent validation successful")
return True, None
error_message = "Agent validation failed:\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
return False, error_message
def validate_agent(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Convenience function to validate an agent.
Returns:
Tuple of (is_valid, error_message)
"""
validator = AgentValidator()
return validator.validate(agent, blocks_info)
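
A short sketch of the convenience wrapper on a deliberately broken, hypothetical agent; an empty `blocks_info` list skips the block-aware checks but still exercises the link-reference validation:

```python
agent = {
    "nodes": [],
    "links": [{"id": "l1", "source_id": "missing-node", "sink_id": "other-missing-node"}],
}
is_valid, error_message = validate_agent(agent, blocks_info=[])
# is_valid is False; error_message lists the dangling source_id/sink_id references.
print(error_message)
```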

View File

@@ -5,6 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
+from langfuse import observe
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
@@ -328,6 +329,7 @@ class AgentOutputTool(BaseTool):
total_executions=len(available_executions) if available_executions else 1,
)
+@observe(as_type="tool", name="view_agent_output")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,13 +3,17 @@
import logging
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
-AgentGeneratorNotConfiguredError,
+apply_all_fixes,
decompose_goal,
generate_agent,
+get_blocks_info,
save_agent_to_library,
+validate_agent,
)
from .base import BaseTool
from .models import (
@@ -23,6 +27,9 @@ from .models import (
logger = logging.getLogger(__name__)
+# Maximum retries for agent generation with validation feedback
+MAX_GENERATION_RETRIES = 2
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@@ -73,6 +80,7 @@ class CreateAgentTool(BaseTool):
"required": ["description"],
}
+@observe(as_type="tool", name="create_agent")
async def _execute(
self,
user_id: str | None,
@@ -83,8 +91,9 @@ class CreateAgentTool(BaseTool):
Flow:
1. Decompose the description into steps (may return clarifying questions)
-2. Generate agent JSON (external service handles fixing and validation)
-3. Preview or save based on the save parameter
+2. Generate agent JSON from the steps
+3. Apply fixes to correct common LLM errors
+4. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
@@ -101,23 +110,18 @@
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(description, context)
-except AgentGeneratorNotConfiguredError:
+except ValueError as e:
+# Handle missing API key or configuration errors
return ErrorResponse(
-message=(
-"Agent generation is not available. "
-"The Agent Generator service is not configured."
-),
-error="service_not_configured",
+message=f"Agent generation is not configured: {str(e)}",
+error="configuration_error",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
-message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
-error="decomposition_failed",
-details={
-"description": description[:100]
-},  # Include context for debugging
+message="Failed to analyze the goal. Please try rephrasing.",
+error="Decomposition failed",
session_id=session_id,
)
@@ -167,35 +171,72 @@
session_id=session_id,
)
-# Step 2: Generate agent JSON (external service handles fixing and validation)
-try:
-agent_json = await generate_agent(decomposition_result)
-except AgentGeneratorNotConfiguredError:
-return ErrorResponse(
-message=(
-"Agent generation is not available. "
-"The Agent Generator service is not configured."
-),
-error="service_not_configured",
-session_id=session_id,
-)
-if agent_json is None:
-return ErrorResponse(
-message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
-error="generation_failed",
-details={
-"description": description[:100]
-},  # Include context for debugging
-session_id=session_id,
-)
+# Step 2: Generate agent JSON with retry on validation failure
+blocks_info = get_blocks_info()
+agent_json = None
+validation_errors = None
+for attempt in range(MAX_GENERATION_RETRIES + 1):
+# Generate agent (include validation errors from previous attempt)
+if attempt == 0:
+agent_json = await generate_agent(decomposition_result)
+else:
+# Retry with validation error feedback
+logger.info(
+f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
+)
+retry_instructions = {
+**decomposition_result,
+"previous_errors": validation_errors,
+"retry_instructions": (
+"The previous generation had validation errors. "
+"Please fix these issues in the new generation:\n"
+f"{validation_errors}"
+),
+}
+agent_json = await generate_agent(retry_instructions)
+if agent_json is None:
+if attempt == MAX_GENERATION_RETRIES:
+return ErrorResponse(
+message="Failed to generate the agent. Please try again.",
+error="Generation failed",
+session_id=session_id,
+)
+continue
+# Step 3: Apply fixes to correct common errors
+agent_json = apply_all_fixes(agent_json, blocks_info)
+# Step 4: Validate the agent
+is_valid, validation_errors = validate_agent(agent_json, blocks_info)
+if is_valid:
+logger.info(f"Agent generated successfully on attempt {attempt + 1}")
+break
+logger.warning(
+f"Validation failed on attempt {attempt + 1}: {validation_errors}"
+)
+if attempt == MAX_GENERATION_RETRIES:
+# Return error with validation details
+return ErrorResponse(
+message=(
+f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
+f"Please try rephrasing your request or simplify the workflow."
+),
+error="validation_failed",
+details={"validation_errors": validation_errors},
+session_id=session_id,
+)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
-# Step 3: Preview or save
+# Step 4: Preview or save
if not save:
return AgentPreviewResponse(
message=(

View File

@@ -3,13 +3,18 @@
import logging
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
-AgentGeneratorNotConfiguredError,
+apply_agent_patch,
+apply_all_fixes,
generate_agent_patch,
get_agent_as_json,
+get_blocks_info,
save_agent_to_library,
+validate_agent,
)
from .base import BaseTool
from .models import (
@@ -23,6 +28,9 @@ from .models import (
logger = logging.getLogger(__name__)
+# Maximum retries for patch generation with validation feedback
+MAX_GENERATION_RETRIES = 2
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@@ -35,7 +43,7 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
-"Generates updates to the agent while preserving unchanged parts."
+"Generates a patch to update the agent while preserving unchanged parts."
)
@property
@@ -79,6 +87,7 @@ class EditAgentTool(BaseTool):
"required": ["agent_id", "changes"],
}
+@observe(as_type="tool", name="edit_agent")
async def _execute(
self,
user_id: str | None,
@@ -89,8 +98,9 @@ class EditAgentTool(BaseTool):
Flow:
1. Fetch the current agent
-2. Generate updated agent (external service handles fixing and validation)
-3. Preview or save based on the save parameter
+2. Generate a patch based on the requested changes
+3. Apply the patch to create an updated agent
+4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
@@ -127,59 +137,121 @@
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
-# Step 2: Generate updated agent (external service handles fixing and validation)
-try:
-result = await generate_agent_patch(update_request, current_agent)
-except AgentGeneratorNotConfiguredError:
-return ErrorResponse(
-message=(
-"Agent editing is not available. "
-"The Agent Generator service is not configured."
-),
-error="service_not_configured",
-session_id=session_id,
-)
-if result is None:
-return ErrorResponse(
-message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
-error="update_generation_failed",
-details={"agent_id": agent_id, "changes": changes[:100]},
-session_id=session_id,
-)
-# Check if LLM returned clarifying questions
-if result.get("type") == "clarifying_questions":
-questions = result.get("questions", [])
-return ClarificationNeededResponse(
-message=(
-"I need some more information about the changes. "
-"Please answer the following questions:"
-),
-questions=[
-ClarifyingQuestion(
-question=q.get("question", ""),
-keyword=q.get("keyword", ""),
-example=q.get("example"),
-)
-for q in questions
-],
-session_id=session_id,
-)
-# Result is the updated agent JSON
-updated_agent = result
+# Step 2: Generate patch with retry on validation failure
+blocks_info = get_blocks_info()
+updated_agent = None
+validation_errors = None
+intent = "Applied requested changes"
+for attempt in range(MAX_GENERATION_RETRIES + 1):
+# Generate patch (include validation errors from previous attempt)
+try:
+if attempt == 0:
+patch_result = await generate_agent_patch(
+update_request, current_agent
+)
+else:
+# Retry with validation error feedback
+logger.info(
+f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
+)
+retry_request = (
+f"{update_request}\n\n"
+f"IMPORTANT: The previous edit had validation errors. "
+f"Please fix these issues:\n{validation_errors}"
+)
+patch_result = await generate_agent_patch(
+retry_request, current_agent
+)
+except ValueError as e:
+# Handle missing API key or configuration errors
+return ErrorResponse(
+message=f"Agent generation is not configured: {str(e)}",
+error="configuration_error",
+session_id=session_id,
+)
+if patch_result is None:
+if attempt == MAX_GENERATION_RETRIES:
+return ErrorResponse(
+message="Failed to generate changes. Please try rephrasing.",
+error="Patch generation failed",
+session_id=session_id,
+)
+continue
+# Check if LLM returned clarifying questions
+if patch_result.get("type") == "clarifying_questions":
+questions = patch_result.get("questions", [])
+return ClarificationNeededResponse(
+message=(
+"I need some more information about the changes. "
+"Please answer the following questions:"
+),
+questions=[
+ClarifyingQuestion(
+question=q.get("question", ""),
+keyword=q.get("keyword", ""),
+example=q.get("example"),
+)
+for q in questions
+],
+session_id=session_id,
+)
+# Step 3: Apply patch and fixes
+try:
+updated_agent = apply_agent_patch(current_agent, patch_result)
+updated_agent = apply_all_fixes(updated_agent, blocks_info)
+except Exception as e:
+if attempt == MAX_GENERATION_RETRIES:
+return ErrorResponse(
+message=f"Failed to apply changes: {str(e)}",
+error="patch_apply_failed",
+details={"exception": str(e)},
+session_id=session_id,
+)
+validation_errors = str(e)
+continue
+# Step 4: Validate the updated agent
+is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
+if is_valid:
+logger.info(f"Agent edited successfully on attempt {attempt + 1}")
+intent = patch_result.get("intent", "Applied requested changes")
+break
+logger.warning(
+f"Validation failed on attempt {attempt + 1}: {validation_errors}"
+)
+if attempt == MAX_GENERATION_RETRIES:
+# Return error with validation details
+return ErrorResponse(
+message=(
+f"Updated agent has validation errors after "
+f"{MAX_GENERATION_RETRIES + 1} attempts. "
+f"Please try rephrasing your request or simplify the changes."
+),
+error="validation_failed",
+details={"validation_errors": validation_errors},
+session_id=session_id,
+)
+# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
+assert updated_agent is not None
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
-# Step 3: Preview or save
+# Step 5: Preview or save
if not save:
return AgentPreviewResponse(
message=(
-f"I've updated the agent. "
+f"I've updated the agent. Changes: {intent}. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
@@ -205,7 +277,10 @@ class EditAgentTool(BaseTool):
)
return AgentSavedResponse(
-message=f"Updated agent '{created_graph.name}' has been saved to your library!",
+message=(
+f"Updated agent '{created_graph.name}' has been saved to your library! "
+f"Changes: {intent}"
+),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,

View File

@@ -2,6 +2,8 @@
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -35,6 +37,7 @@ class FindAgentTool(BaseTool):
"required": ["query"],
}
+@observe(as_type="tool", name="find_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -1,6 +1,7 @@
import logging
from typing import Any
+from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -55,6 +56,7 @@ class FindBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
+@observe(as_type="tool", name="find_block")
async def _execute(
self,
user_id: str | None,
@@ -107,8 +109,7 @@
block_id = result["content_id"]
block = get_block(block_id)
-# Skip disabled blocks
-if block and not block.disabled:
+if block:
# Get input/output schemas
input_schema = {}
output_schema = {}

View File

@@ -2,6 +2,8 @@
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -41,6 +43,7 @@ class FindLibraryAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
+@observe(as_type="tool", name="find_library_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -4,6 +4,8 @@ import logging
from pathlib import Path
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -71,6 +73,7 @@ class GetDocPageTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
+@observe(as_type="tool", name="get_doc_page")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,14 +3,11 @@
import logging
from typing import Any
+from langfuse import observe
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tracking import (
-track_agent_run_success,
-track_agent_scheduled,
-)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
@@ -158,6 +155,7 @@ class RunAgentTool(BaseTool):
"""All operations require authentication."""
return True
+@observe(as_type="tool", name="run_agent")
async def _execute(
self,
user_id: str | None,
@@ -455,16 +453,6 @@
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
)
-# Track in PostHog
-track_agent_run_success(
-user_id=user_id,
-session_id=session_id,
-graph_id=library_agent.graph_id,
-graph_name=library_agent.name,
-execution_id=execution.id,
-library_agent_id=library_agent.id,
-)
library_agent_link = f"/library/agents/{library_agent.id}"
return ExecutionStartedResponse(
message=(
@@ -546,18 +534,6 @@
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
)
-# Track in PostHog
-track_agent_scheduled(
-user_id=user_id,
-session_id=session_id,
-graph_id=library_agent.graph_id,
-graph_name=library_agent.name,
-schedule_id=result.id,
-schedule_name=schedule_name,
-cron=cron,
-library_agent_id=library_agent.id,
-)
library_agent_link = f"/library/agents/{library_agent.id}"
return ExecutionStartedResponse(
message=(

View File

@@ -29,7 +29,7 @@ def mock_embedding_functions():
yield
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent(setup_test_data):
"""Test that the run_agent tool successfully executes an approved agent"""
# Use test data from fixture
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
assert result_data["graph_name"] == "Test Agent"
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_missing_inputs(setup_test_data):
"""Test that the run_agent tool returns error when inputs are missing"""
# Use test data from fixture
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
assert "message" in result_data
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_invalid_agent_id(setup_test_data):
"""Test that the run_agent tool returns error for invalid agent ID"""
# Use test data from fixture
@@ -141,7 +141,7 @@
)
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
"""Test that run_agent works with an agent requiring LLM credentials"""
# Use test data from fixture
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
assert result_data["graph_name"] == "LLM Test Agent"
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
user = setup_test_data["user"]
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
assert "inputs" in result_data["message"].lower()
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_with_use_defaults(setup_test_data):
"""Test that run_agent executes successfully with use_defaults=True."""
user = setup_test_data["user"]
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
assert result_data["graph_id"] == graph.id
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
"""Test that run_agent returns setup_requirements when credentials are missing."""
user = setup_firecrawl_test_data["user"]
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_invalid_slug_format(setup_test_data):
"""Test that run_agent returns error for invalid slug format (no slash)."""
user = setup_test_data["user"]
@@ -313,7 +313,7 @@
assert "username/agent-name" in result_data["message"]
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_unauthenticated():
"""Test that run_agent returns need_login for unauthenticated users."""
tool = RunAgentTool()
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
assert "sign in" in result_data["message"].lower()
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_schedule_without_cron(setup_test_data):
"""Test that run_agent returns error when scheduling without cron expression."""
user = setup_test_data["user"]
@@ -372,7 +372,7 @@
assert "cron" in result_data["message"].lower()
-@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.asyncio(scope="session")
async def test_run_agent_schedule_without_name(setup_test_data):
"""Test that run_agent returns error when scheduling without schedule_name."""
user = setup_test_data["user"]

View File

@@ -4,6 +4,8 @@ import logging
from collections import defaultdict
from typing import Any
+from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
@@ -128,6 +130,7 @@ class RunBlockTool(BaseTool):
return matched_credentials, missing_credentials
+@observe(as_type="tool", name="run_block")
async def _execute(
self,
user_id: str | None,
@@ -176,11 +179,6 @@
message=f"Block '{block_id}' not found",
session_id=session_id,
)
-if block.disabled:
-return ErrorResponse(
-message=f"Block '{block_id}' is disabled",
-session_id=session_id,
-)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")


@@ -3,6 +3,7 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -87,6 +88,7 @@ class SearchDocsTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="search_docs")
async def _execute(
self,
user_id: str | None,


@@ -1,250 +0,0 @@
"""PostHog analytics tracking for the chat system."""
import atexit
import logging
from typing import Any
from posthog import Posthog
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
# PostHog client instance (lazily initialized)
_posthog_client: Posthog | None = None
def _shutdown_posthog() -> None:
"""Flush and shutdown PostHog client on process exit."""
if _posthog_client is not None:
_posthog_client.flush()
_posthog_client.shutdown()
atexit.register(_shutdown_posthog)
def _get_posthog_client() -> Posthog | None:
"""Get or create the PostHog client instance."""
global _posthog_client
if _posthog_client is not None:
return _posthog_client
if not settings.secrets.posthog_api_key:
logger.debug("PostHog API key not configured, analytics disabled")
return None
_posthog_client = Posthog(
settings.secrets.posthog_api_key,
host=settings.secrets.posthog_host,
)
logger.info(
f"PostHog client initialized with host: {settings.secrets.posthog_host}"
)
return _posthog_client
def _get_base_properties() -> dict[str, Any]:
"""Get base properties included in all events."""
return {
"environment": settings.config.app_env.value,
"source": "chat_copilot",
}
def track_user_message(
user_id: str | None,
session_id: str,
message_length: int,
) -> None:
"""Track when a user sends a message in chat.
Args:
user_id: The user's ID (or None for anonymous)
session_id: The chat session ID
message_length: Length of the user's message
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"message_length": message_length,
}
client.capture(
distinct_id=user_id or f"anonymous_{session_id}",
event="copilot_message_sent",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track user message: {e}")
def track_tool_called(
user_id: str | None,
session_id: str,
tool_name: str,
tool_call_id: str,
) -> None:
"""Track when a tool is called in chat.
Args:
user_id: The user's ID (or None for anonymous)
session_id: The chat session ID
tool_name: Name of the tool being called
tool_call_id: Unique ID of the tool call
"""
client = _get_posthog_client()
if not client:
logger.info("PostHog client not available for tool tracking")
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"tool_name": tool_name,
"tool_call_id": tool_call_id,
}
distinct_id = user_id or f"anonymous_{session_id}"
logger.info(
f"Sending copilot_tool_called event to PostHog: distinct_id={distinct_id}, "
f"tool_name={tool_name}"
)
client.capture(
distinct_id=distinct_id,
event="copilot_tool_called",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track tool call: {e}")
def track_agent_run_success(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
execution_id: str,
library_agent_id: str,
) -> None:
"""Track when an agent is successfully run.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
execution_id: ID of the execution
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"execution_id": execution_id,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_agent_run_success",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track agent run: {e}")
def track_agent_scheduled(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
schedule_id: str,
schedule_name: str,
cron: str,
library_agent_id: str,
) -> None:
"""Track when an agent is successfully scheduled.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
schedule_id: ID of the schedule
schedule_name: Name of the schedule
cron: Cron expression for the schedule
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"schedule_id": schedule_id,
"schedule_name": schedule_name,
"cron": cron,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_agent_scheduled",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track agent schedule: {e}")
def track_trigger_setup(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
trigger_type: str,
library_agent_id: str,
) -> None:
"""Track when a trigger is set up for an agent.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
trigger_type: Type of trigger (e.g., 'webhook')
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"trigger_type": trigger_type,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_trigger_setup",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track trigger setup: {e}")


@@ -23,7 +23,6 @@ class PendingHumanReviewModel(BaseModel):
id: Unique identifier for the review record
user_id: ID of the user who must perform the review
node_exec_id: ID of the node execution that created this review
node_id: ID of the node definition (for grouping reviews from same node)
graph_exec_id: ID of the graph execution containing the node
graph_id: ID of the graph template being executed
graph_version: Version number of the graph template
@@ -38,10 +37,6 @@ class PendingHumanReviewModel(BaseModel):
"""
node_exec_id: str = Field(description="Node execution ID (primary key)")
node_id: str = Field(
description="Node definition ID (for grouping)",
default="", # Temporary default for test compatibility
)
user_id: str = Field(description="User ID associated with the review")
graph_exec_id: str = Field(description="Graph execution ID")
graph_id: str = Field(description="Graph ID")
@@ -71,9 +66,7 @@ class PendingHumanReviewModel(BaseModel):
)
@classmethod
def from_db(
cls, review: "PendingHumanReview", node_id: str
) -> "PendingHumanReviewModel":
def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
"""
Convert a database model to a response model.
@@ -81,14 +74,9 @@ class PendingHumanReviewModel(BaseModel):
payload, instructions, and editable flag.
Handles invalid data gracefully by using safe defaults.
Args:
review: Database review object
node_id: Node definition ID (fetched from NodeExecution)
"""
return cls(
node_exec_id=review.nodeExecId,
node_id=node_id,
user_id=review.userId,
graph_exec_id=review.graphExecId,
graph_id=review.graphId,
@@ -119,13 +107,6 @@ class ReviewItem(BaseModel):
reviewed_data: SafeJsonData | None = Field(
None, description="Optional edited data (ignored if approved=False)"
)
auto_approve_future: bool = Field(
default=False,
description=(
"If true and this review is approved, future executions of this same "
"block (node) will be automatically approved. This only affects approved reviews."
),
)
@field_validator("reviewed_data")
@classmethod
@@ -193,9 +174,6 @@ class ReviewRequest(BaseModel):
This request must include ALL pending reviews for a graph execution.
Each review will be either approved (with optional data modifications)
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
Each review item can individually specify whether to auto-approve future executions
of the same block via the `auto_approve_future` field on ReviewItem.
"""
reviews: List[ReviewItem] = Field(
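For orientation, a hedged example of a request body this endpoint could accept after the simplification, using only the ReviewItem fields visible in this diff (node_exec_id, approved, reviewed_data, message); all IDs and values are placeholders.
# --- illustrative sketch, not part of this diff ---
review_request_body = {
    "reviews": [
        {
            "node_exec_id": "node-exec-123",  # placeholder ID
            "approved": True,
            "reviewed_data": {"recipient": "ops@example.com"},  # optional edits
            "message": "Looks correct",
        },
        {
            "node_exec_id": "node-exec-456",  # placeholder ID
            "approved": False,  # rejection: reviewed_data is ignored
            "reviewed_data": None,
            "message": "Wrong account",
        },
    ]
}
# --- end sketch ---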


@@ -1,27 +1,17 @@
import asyncio
import logging
from typing import Any, List
from typing import List
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
)
from backend.data.execution import get_graph_execution_meta
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
create_auto_approval_record,
get_pending_reviews_for_execution,
get_pending_reviews_for_user,
get_reviews_by_node_exec_ids,
has_pending_reviews_for_graph_exec,
process_all_reviews_for_execution,
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -137,70 +127,17 @@ async def process_review_action(
detail="At least one review must be provided", detail="At least one review must be provided",
) )
# Batch fetch all requested reviews (regardless of status for idempotent handling) # Build review decisions map
reviews_map = await get_reviews_by_node_exec_ids(
list(all_request_node_ids), user_id
)
# Validate all reviews were found (must exist, any status is OK for now)
missing_ids = all_request_node_ids - set(reviews_map.keys())
if missing_ids:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Review(s) not found: {', '.join(missing_ids)}",
)
# Validate all reviews belong to the same execution
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
if len(graph_exec_ids) > 1:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="All reviews in a single request must belong to the same execution.",
)
graph_exec_id = next(iter(graph_exec_ids))
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
review_decisions = {}
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
for review in request.reviews:
review_status = (
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
)
# If this review requested auto-approval, don't allow data modifications
reviewed_data = None if review.auto_approve_future else review.reviewed_data
review_decisions[review.node_exec_id] = (
review_status,
reviewed_data,
review.reviewed_data,
review.message,
)
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
# Process all reviews
updated_reviews = await process_all_reviews_for_execution(
@@ -208,87 +145,6 @@ async def process_review_action(
review_decisions=review_decisions,
)
# Create auto-approval records for approved reviews that requested it
# Deduplicate by node_id to avoid race conditions when multiple reviews
# for the same node are processed in parallel
async def create_auto_approval_for_node(
node_id: str, review_result
) -> tuple[str, bool]:
"""
Create auto-approval record for a node.
Returns (node_id, success) tuple for tracking failures.
"""
try:
await create_auto_approval_record(
user_id=user_id,
graph_exec_id=review_result.graph_exec_id,
graph_id=review_result.graph_id,
graph_version=review_result.graph_version,
node_id=node_id,
payload=review_result.payload,
)
return (node_id, True)
except Exception as e:
logger.error(
f"Failed to create auto-approval record for node {node_id}",
exc_info=e,
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
if review_result.status == ReviewStatus.APPROVED
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
*[
create_auto_approval_for_node(node_id, review_result)
for node_id, review_result in nodes_needing_auto_approval.items()
],
return_exceptions=True,
)
# Count auto-approval failures
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
approved_count = sum(
1
@@ -301,53 +157,30 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution only if ALL pending reviews for this execution have been processed
# Resume execution if we processed some reviews
if updated_reviews:
# Get graph execution ID from any processed review
first_review = next(iter(updated_reviews.values()))
graph_exec_id = first_review.graph_exec_id
# Check if any pending reviews remain for this execution
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
# Resume execution
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
)
await add_graph_execution(
graph_id=first_review.graph_id,
user_id=user_id,
graph_exec_id=graph_exec_id,
execution_context=execution_context,
)
logger.info(f"Resumed execution {graph_exec_id}")
except Exception as e:
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
# Build error message if auto-approvals failed
error_message = None
if auto_approval_failed_count > 0:
error_message = (
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
f"You may need to manually approve these reviews in future executions."
)
return ReviewResponse(
approved_count=approved_count,
rejected_count=rejected_count,
failed_count=auto_approval_failed_count,
failed_count=0,
error=error_message,
error=None,
)


@@ -401,11 +401,27 @@ async def add_generated_agent_image(
)
def _initialize_graph_settings(graph: graph_db.GraphModel) -> GraphSettings:
"""
Initialize GraphSettings based on graph content.
Args:
graph: The graph to analyze
Returns:
GraphSettings with appropriate human_in_the_loop_safe_mode value
"""
if graph.has_human_in_the_loop:
# Graph has HITL blocks - set safe mode to True by default
return GraphSettings(human_in_the_loop_safe_mode=True)
else:
# Graph has no HITL blocks - keep None
return GraphSettings(human_in_the_loop_safe_mode=None)
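A small sketch of the behaviour _initialize_graph_settings is expected to have, assuming GraphSettings is a pydantic model with an optional human_in_the_loop_safe_mode field; the graph objects below are stand-ins, not real graph_db.GraphModel instances.
# --- illustrative sketch, not part of this diff ---
from types import SimpleNamespace

# Stand-in graphs; the real GraphModel also exposes has_human_in_the_loop.
graph_with_hitl = SimpleNamespace(has_human_in_the_loop=True)
graph_without_hitl = SimpleNamespace(has_human_in_the_loop=False)

assert _initialize_graph_settings(graph_with_hitl).human_in_the_loop_safe_mode is True
assert _initialize_graph_settings(graph_without_hitl).human_in_the_loop_safe_mode is None
# --- end sketch ---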
async def create_library_agent(
graph: graph_db.GraphModel,
user_id: str,
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
) -> list[library_model.LibraryAgent]:
"""
@@ -414,8 +430,6 @@ async def create_library_agent(
Args:
agent: The agent/Graph to add to the library.
user_id: The user to whom the agent will be added.
hitl_safe_mode: Whether HITL blocks require manual review (default True).
sensitive_action_safe_mode: Whether sensitive action blocks require review.
create_library_agents_for_sub_graphs: If True, creates LibraryAgent records for sub-graphs as well.
Returns:
@@ -451,11 +465,7 @@ async def create_library_agent(
}
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
_initialize_graph_settings(graph_entry).model_dump()
),
),
include=library_agent_include(
@@ -583,13 +593,7 @@ async def update_library_agent(
)
update_fields["isDeleted"] = is_deleted
if settings is not None:
existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
current_settings_dict = (
existing_agent.settings.model_dump() if existing_agent.settings else {}
)
new_settings = settings.model_dump(exclude_unset=True)
merged_settings = {**current_settings_dict, **new_settings}
update_fields["settings"] = SafeJson(merged_settings)
update_fields["settings"] = SafeJson(settings.model_dump())
try:
# If graph_version is provided, update to that specific version
@@ -623,6 +627,33 @@ async def update_library_agent(
raise DatabaseError("Failed to update library agent") from e
async def update_library_agent_settings(
user_id: str,
agent_id: str,
settings: GraphSettings,
) -> library_model.LibraryAgent:
"""
Updates the settings for a specific LibraryAgent.
Args:
user_id: The owner of the LibraryAgent.
agent_id: The ID of the LibraryAgent to update.
settings: New GraphSettings to apply.
Returns:
The updated LibraryAgent.
Raises:
NotFoundError: If the specified LibraryAgent does not exist.
DatabaseError: If there's an error in the update operation.
"""
return await update_library_agent(
library_agent_id=agent_id,
user_id=user_id,
settings=settings,
)
async def delete_library_agent(
library_agent_id: str, user_id: str, soft_delete: bool = True
) -> None:
@@ -807,7 +838,7 @@ async def add_store_agent_to_library(
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(
GraphSettings.from_graph(graph_model).model_dump()
_initialize_graph_settings(graph_model).model_dump()
),
},
include=library_agent_include(
@@ -1197,15 +1228,8 @@ async def fork_library_agent(
)
new_graph = await on_graph_activate(new_graph, user_id=user_id)
# Create a library agent for the new graph, preserving safe mode settings
# Create a library agent for the new graph
return (
await create_library_agent(
new_graph,
user_id,
hitl_safe_mode=original_agent.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=original_agent.settings.sensitive_action_safe_mode,
)
)[0]
return (await create_library_agent(new_graph, user_id))[0]
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
raise DatabaseError("Failed to fork library agent") from e


@@ -73,12 +73,6 @@ class LibraryAgent(pydantic.BaseModel):
has_external_trigger: bool = pydantic.Field(
description="Whether the agent has an external trigger (e.g. webhook) node"
)
has_human_in_the_loop: bool = pydantic.Field(
description="Whether the agent has human-in-the-loop blocks"
)
has_sensitive_action: bool = pydantic.Field(
description="Whether the agent has sensitive action blocks"
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
@@ -186,8 +180,6 @@ class LibraryAgent(pydantic.BaseModel):
graph.credentials_input_schema if sub_graphs is not None else None
),
has_external_trigger=graph.has_external_trigger,
has_human_in_the_loop=graph.has_human_in_the_loop,
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
can_access_graph=can_access_graph,


@@ -52,8 +52,6 @@ async def test_get_library_agents_success(
output_schema={"type": "object", "properties": {}}, output_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}}, credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False, has_external_trigger=False,
has_human_in_the_loop=False,
has_sensitive_action=False,
status=library_model.LibraryAgentStatus.COMPLETED, status=library_model.LibraryAgentStatus.COMPLETED,
recommended_schedule_cron=None, recommended_schedule_cron=None,
new_output=False, new_output=False,
@@ -77,8 +75,6 @@ async def test_get_library_agents_success(
output_schema={"type": "object", "properties": {}}, output_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}}, credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False, has_external_trigger=False,
has_human_in_the_loop=False,
has_sensitive_action=False,
status=library_model.LibraryAgentStatus.COMPLETED, status=library_model.LibraryAgentStatus.COMPLETED,
recommended_schedule_cron=None, recommended_schedule_cron=None,
new_output=False, new_output=False,
@@ -154,8 +150,6 @@ async def test_get_favorite_library_agents_success(
output_schema={"type": "object", "properties": {}}, output_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}}, credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False, has_external_trigger=False,
has_human_in_the_loop=False,
has_sensitive_action=False,
status=library_model.LibraryAgentStatus.COMPLETED, status=library_model.LibraryAgentStatus.COMPLETED,
recommended_schedule_cron=None, recommended_schedule_cron=None,
new_output=False, new_output=False,
@@ -224,8 +218,6 @@ def test_add_agent_to_library_success(
output_schema={"type": "object", "properties": {}}, output_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}}, credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False, has_external_trigger=False,
has_human_in_the_loop=False,
has_sensitive_action=False,
status=library_model.LibraryAgentStatus.COMPLETED, status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False, new_output=False,
can_access_graph=True, can_access_graph=True,


@@ -20,7 +20,6 @@ from typing import AsyncGenerator
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
@@ -39,13 +38,13 @@ keysmith = APIKeySmith()
# ============================================================================
@pytest.fixture(scope="session")
@pytest.fixture
def test_user_id() -> str:
"""Test user ID for OAuth tests."""
return str(uuid.uuid4())
@pytest_asyncio.fixture(scope="session", loop_scope="session")
@pytest.fixture
async def test_user(server, test_user_id: str):
"""Create a test user in the database."""
await PrismaUser.prisma().create(
@@ -68,7 +67,7 @@ async def test_user(server, test_user_id: str):
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture
@pytest.fixture
async def test_oauth_app(test_user: str):
"""Create a test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -123,7 +122,7 @@ def pkce_credentials() -> tuple[str, str]:
return generate_pkce()
@pytest_asyncio.fixture
@pytest.fixture
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
"""
Create an async HTTP client that talks directly to the FastAPI app.
@@ -288,7 +287,7 @@ async def test_authorize_invalid_client_returns_error(
assert query_params["error"][0] == "invalid_client"
@pytest_asyncio.fixture
@pytest.fixture
async def inactive_oauth_app(test_user: str):
"""Create an inactive test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -1005,7 +1004,7 @@ async def test_token_refresh_revoked(
assert "revoked" in response.json()["detail"].lower()
@pytest_asyncio.fixture
@pytest.fixture
async def other_oauth_app(test_user: str):
"""Create a second OAuth application for cross-app tests."""
app_id = str(uuid.uuid4())


@@ -188,10 +188,6 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
@@ -252,19 +248,12 @@
from backend.data.block import get_blocks
all_blocks = get_blocks()
total_blocks = len(all_blocks)
# Filter out disabled blocks - they're not indexed
enabled_block_ids = [
block_id
for block_id, block_cls in all_blocks.items()
if not block_cls().disabled
]
total_blocks = len(enabled_block_ids)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = enabled_block_ids
block_ids = list(all_blocks.keys())
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(


@@ -81,7 +81,6 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
@@ -117,18 +116,11 @@ async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
# Mock get_blocks - each block class returns an instance with disabled=False
# Mock get_blocks
def make_mock_block_class():
mock_class = MagicMock()
mock_instance = MagicMock()
mock_instance.disabled = False
mock_class.return_value = mock_instance
return mock_class
mock_blocks = {
"block-1": make_mock_block_class(),
"block-1": MagicMock(),
"block-2": make_mock_block_class(),
"block-2": MagicMock(),
"block-3": make_mock_block_class(),
"block-3": MagicMock(),
}
# Mock embedded count query (2 blocks have embeddings)
@@ -317,7 +309,6 @@ async def test_block_handler_handles_missing_attributes():
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
@@ -351,7 +342,6 @@ async def test_block_handler_skips_failed_blocks():
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_block.return_value = good_instance
bad_block = MagicMock()


@@ -22,6 +22,7 @@ from backend.data.notifications import (
AgentApprovalData,
AgentRejectionData,
NotificationEventModel,
WaitlistLaunchData,
)
from backend.notifications.notifications import queue_notification_async
from backend.util.exceptions import DatabaseError
@@ -1552,7 +1553,7 @@ async def review_store_submission(
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
await ensure_embedding(
embedding_success = await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
@@ -1560,6 +1561,12 @@ async def review_store_submission(
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},
@@ -1711,6 +1718,29 @@ async def review_store_submission(
# Don't fail the review process if email sending fails
pass
# Notify waitlist users if this is an approval and has a linked waitlist
if is_approved and submission.StoreListing:
try:
frontend_base_url = (
settings.config.frontend_base_url
or settings.config.platform_base_url
)
store_agent = (
await prisma.models.StoreAgent.prisma().find_first_or_raise(
where={"storeListingVersionId": submission.id}
)
)
creator_username = store_agent.creator_username or "unknown"
store_url = f"{frontend_base_url}/marketplace/agent/{creator_username}/{store_agent.slug}"
await notify_waitlist_users_on_launch(
store_listing_id=submission.StoreListing.id,
agent_name=submission.name,
store_url=store_url,
)
except Exception as e:
logger.error(f"Failed to notify waitlist users on agent approval: {e}")
# Don't fail the approval process
# Convert to Pydantic model for consistency
return store_model.StoreSubmission(
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
@@ -1958,3 +1988,552 @@ async def get_agent_as_admin(
)
return graph
def _waitlist_to_store_entry(
waitlist: prisma.models.WaitlistEntry,
) -> store_model.StoreWaitlistEntry:
"""Convert a WaitlistEntry to StoreWaitlistEntry for public display."""
return store_model.StoreWaitlistEntry(
waitlistId=waitlist.id,
slug=waitlist.slug,
name=waitlist.name,
subHeading=waitlist.subHeading,
videoUrl=waitlist.videoUrl,
agentOutputDemoUrl=waitlist.agentOutputDemoUrl,
imageUrls=waitlist.imageUrls or [],
description=waitlist.description,
categories=waitlist.categories,
)
async def get_waitlist() -> list[store_model.StoreWaitlistEntry]:
"""Get all active waitlists for public display."""
try:
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where=prisma.types.WaitlistEntryWhereInput(isDeleted=False),
)
# Filter out closed/done waitlists and sort by votes (descending)
excluded_statuses = {
prisma.enums.WaitlistExternalStatus.CANCELED,
prisma.enums.WaitlistExternalStatus.DONE,
}
active_waitlists = [w for w in waitlists if w.status not in excluded_statuses]
sorted_list = sorted(active_waitlists, key=lambda x: x.votes, reverse=True)
return [_waitlist_to_store_entry(w) for w in sorted_list]
except Exception as e:
logger.error(f"Error fetching waitlists: {e}")
raise DatabaseError("Failed to fetch waitlists") from e
async def get_user_waitlist_memberships(user_id: str) -> list[str]:
"""Get all waitlist IDs that a user has joined."""
try:
user = await prisma.models.User.prisma().find_unique(
where={"id": user_id},
include={"joinedWaitlists": True},
)
if not user or not user.joinedWaitlists:
return []
return [w.id for w in user.joinedWaitlists]
except Exception as e:
logger.error(f"Error fetching user waitlist memberships: {e}")
raise DatabaseError("Failed to fetch waitlist memberships") from e
async def add_user_to_waitlist(
waitlist_id: str, user_id: str | None, email: str | None
) -> store_model.StoreWaitlistEntry:
"""
Add a user to a waitlist.
For logged-in users: connects via joinedUsers relation
For anonymous users: adds email to unaffiliatedEmailUsers array
"""
if not user_id and not email:
raise ValueError("Either user_id or email must be provided")
try:
# Find the waitlist
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"joinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} is no longer available")
if waitlist.status in [
prisma.enums.WaitlistExternalStatus.CANCELED,
prisma.enums.WaitlistExternalStatus.DONE,
]:
raise ValueError(f"Waitlist {waitlist_id} is closed")
if user_id:
# Check if user already joined
joined_user_ids = [u.id for u in (waitlist.joinedUsers or [])]
if user_id in joined_user_ids:
# Already joined - return waitlist info
logger.debug(f"User {user_id} already joined waitlist {waitlist_id}")
else:
# Connect user to waitlist
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"joinedUsers": {"connect": [{"id": user_id}]}},
)
logger.info(f"User {user_id} joined waitlist {waitlist_id}")
# If user was previously in email list, remove them
# Use transaction to prevent race conditions
if email:
async with transaction() as tx:
current_waitlist = await tx.waitlistentry.find_unique(
where={"id": waitlist_id}
)
if current_waitlist and email in (
current_waitlist.unaffiliatedEmailUsers or []
):
updated_emails: list[str] = [
e
for e in (current_waitlist.unaffiliatedEmailUsers or [])
if e != email
]
await tx.waitlistentry.update(
where={"id": waitlist_id},
data={"unaffiliatedEmailUsers": updated_emails},
)
elif email:
# Add email to unaffiliated list if not already present
# Use transaction to prevent race conditions with concurrent signups
async with transaction() as tx:
# Re-fetch within transaction to get latest state
current_waitlist = await tx.waitlistentry.find_unique(
where={"id": waitlist_id}
)
if current_waitlist:
current_emails: list[str] = list(
current_waitlist.unaffiliatedEmailUsers or []
)
if email not in current_emails:
current_emails.append(email)
await tx.waitlistentry.update(
where={"id": waitlist_id},
data={"unaffiliatedEmailUsers": current_emails},
)
logger.info(f"Email {email} added to waitlist {waitlist_id}")
else:
logger.debug(f"Email {email} already on waitlist {waitlist_id}")
# Re-fetch to return updated data
updated_waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
return _waitlist_to_store_entry(updated_waitlist or waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error adding user to waitlist: {e}")
raise DatabaseError("Failed to add user to waitlist") from e
# ============== Admin Waitlist Functions ==============
def _waitlist_to_admin_response(
waitlist: prisma.models.WaitlistEntry,
) -> store_model.WaitlistAdminResponse:
"""Convert a WaitlistEntry to WaitlistAdminResponse."""
joined_count = len(waitlist.joinedUsers) if waitlist.joinedUsers else 0
email_count = (
len(waitlist.unaffiliatedEmailUsers) if waitlist.unaffiliatedEmailUsers else 0
)
return store_model.WaitlistAdminResponse(
id=waitlist.id,
createdAt=waitlist.createdAt.isoformat() if waitlist.createdAt else "",
updatedAt=waitlist.updatedAt.isoformat() if waitlist.updatedAt else "",
slug=waitlist.slug,
name=waitlist.name,
subHeading=waitlist.subHeading,
description=waitlist.description,
categories=waitlist.categories,
imageUrls=waitlist.imageUrls or [],
videoUrl=waitlist.videoUrl,
agentOutputDemoUrl=waitlist.agentOutputDemoUrl,
status=waitlist.status or prisma.enums.WaitlistExternalStatus.NOT_STARTED,
votes=waitlist.votes,
signupCount=joined_count + email_count,
storeListingId=waitlist.storeListingId,
owningUserId=waitlist.owningUserId,
)
async def create_waitlist_admin(
admin_user_id: str,
data: store_model.WaitlistCreateRequest,
) -> store_model.WaitlistAdminResponse:
"""Create a new waitlist (admin only)."""
logger.info(f"Admin {admin_user_id} creating waitlist: {data.name}")
try:
waitlist = await prisma.models.WaitlistEntry.prisma().create(
data=prisma.types.WaitlistEntryCreateInput(
name=data.name,
slug=data.slug,
subHeading=data.subHeading,
description=data.description,
categories=data.categories,
imageUrls=data.imageUrls,
videoUrl=data.videoUrl,
agentOutputDemoUrl=data.agentOutputDemoUrl,
owningUserId=admin_user_id,
status=prisma.enums.WaitlistExternalStatus.NOT_STARTED,
),
include={"joinedUsers": True},
)
return _waitlist_to_admin_response(waitlist)
except Exception as e:
logger.error(f"Error creating waitlist: {e}")
raise DatabaseError("Failed to create waitlist") from e
async def get_waitlists_admin() -> store_model.WaitlistAdminListResponse:
"""Get all waitlists with admin details."""
try:
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where=prisma.types.WaitlistEntryWhereInput(isDeleted=False),
include={"joinedUsers": True},
order={"createdAt": "desc"},
)
return store_model.WaitlistAdminListResponse(
waitlists=[_waitlist_to_admin_response(w) for w in waitlists],
totalCount=len(waitlists),
)
except Exception as e:
logger.error(f"Error fetching waitlists for admin: {e}")
raise DatabaseError("Failed to fetch waitlists") from e
async def get_waitlist_admin(
waitlist_id: str,
) -> store_model.WaitlistAdminResponse:
"""Get a single waitlist with admin details."""
try:
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"joinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
return _waitlist_to_admin_response(waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to fetch waitlist") from e
async def update_waitlist_admin(
waitlist_id: str,
data: store_model.WaitlistUpdateRequest,
) -> store_model.WaitlistAdminResponse:
"""Update a waitlist (admin only)."""
logger.info(f"Updating waitlist {waitlist_id}")
try:
# Check if waitlist exists first
existing = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
if not existing:
raise ValueError(f"Waitlist {waitlist_id} not found")
if existing.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
# Build update data from explicitly provided fields
# Use model_fields_set to allow clearing fields by setting them to None
field_mappings = {
"name": data.name,
"slug": data.slug,
"subHeading": data.subHeading,
"description": data.description,
"categories": data.categories,
"imageUrls": data.imageUrls,
"videoUrl": data.videoUrl,
"agentOutputDemoUrl": data.agentOutputDemoUrl,
"storeListingId": data.storeListingId,
}
update_data: dict[str, Any] = {
k: v for k, v in field_mappings.items() if k in data.model_fields_set
}
# Add status if provided (already validated as enum by Pydantic)
if "status" in data.model_fields_set and data.status is not None:
update_data["status"] = data.status
if not update_data:
# No updates, just return current data
return await get_waitlist_admin(waitlist_id)
waitlist = await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data=prisma.types.WaitlistEntryUpdateInput(**update_data),
include={"joinedUsers": True},
)
# We already verified existence above, so this should never be None
assert waitlist is not None
return _waitlist_to_admin_response(waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error updating waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to update waitlist") from e
async def delete_waitlist_admin(waitlist_id: str) -> None:
"""Soft delete a waitlist (admin only)."""
logger.info(f"Soft deleting waitlist {waitlist_id}")
try:
# Check if waitlist exists first
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has already been deleted")
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"isDeleted": True},
)
except ValueError:
raise
except Exception as e:
logger.error(f"Error deleting waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to delete waitlist") from e
async def get_waitlist_signups_admin(
waitlist_id: str,
) -> store_model.WaitlistSignupListResponse:
"""Get all signups for a waitlist (admin only)."""
try:
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"joinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
signups: list[store_model.WaitlistSignup] = []
# Add user signups
for user in waitlist.joinedUsers or []:
signups.append(
store_model.WaitlistSignup(
type="user",
userId=user.id,
email=user.email,
username=user.name,
)
)
# Add email signups
for email in waitlist.unaffiliatedEmailUsers or []:
signups.append(
store_model.WaitlistSignup(
type="email",
email=email,
)
)
return store_model.WaitlistSignupListResponse(
waitlistId=waitlist_id,
signups=signups,
totalCount=len(signups),
)
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching signups for waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to fetch waitlist signups") from e
async def link_waitlist_to_listing_admin(
waitlist_id: str,
store_listing_id: str,
) -> store_model.WaitlistAdminResponse:
"""Link a waitlist to a store listing (admin only)."""
logger.info(f"Linking waitlist {waitlist_id} to listing {store_listing_id}")
try:
# Verify the waitlist exists
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
# Verify the store listing exists
listing = await prisma.models.StoreListing.prisma().find_unique(
where={"id": store_listing_id}
)
if not listing:
raise ValueError(f"Store listing {store_listing_id} not found")
updated_waitlist = await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"StoreListing": {"connect": {"id": store_listing_id}}},
include={"joinedUsers": True},
)
# We already verified existence above, so this should never be None
assert updated_waitlist is not None
return _waitlist_to_admin_response(updated_waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error linking waitlist to listing: {e}")
raise DatabaseError("Failed to link waitlist to listing") from e
async def notify_waitlist_users_on_launch(
store_listing_id: str,
agent_name: str,
store_url: str,
) -> int:
"""
Notify all users on waitlists linked to a store listing when the agent is launched.
Args:
store_listing_id: The ID of the store listing that was approved
agent_name: The name of the approved agent
store_url: The URL to the agent's store page
Returns:
The number of notifications sent
"""
logger.info(f"Notifying waitlist users for store listing {store_listing_id}")
try:
# Find all active waitlists linked to this store listing
# Exclude DONE and CANCELED to prevent duplicate notifications on re-approval
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where={
"storeListingId": store_listing_id,
"isDeleted": False,
"status": {
"not_in": [
prisma.enums.WaitlistExternalStatus.DONE,
prisma.enums.WaitlistExternalStatus.CANCELED,
]
},
},
include={"joinedUsers": True},
)
if not waitlists:
logger.info(
f"No active waitlists found for store listing {store_listing_id}"
)
return 0
notification_count = 0
launched_at = datetime.now(tz=timezone.utc)
for waitlist in waitlists:
# Track notification results for this waitlist
users_to_notify = waitlist.joinedUsers or []
failed_user_ids: list[str] = []
# Notify registered users
for user in users_to_notify:
try:
notification_data = WaitlistLaunchData(
agent_name=agent_name,
waitlist_name=waitlist.name,
store_url=store_url,
launched_at=launched_at,
)
notification_event = NotificationEventModel[WaitlistLaunchData](
user_id=user.id,
type=prisma.enums.NotificationType.WAITLIST_LAUNCH,
data=notification_data,
)
await queue_notification_async(notification_event)
notification_count += 1
except Exception as e:
logger.error(
f"Failed to send waitlist launch notification to user {user.id}: {e}"
)
failed_user_ids.append(user.id)
# Note: For unaffiliated email users, you would need to send emails directly
# since they don't have user IDs for the notification system.
# This could be done via a separate email service.
# For now, we log these for potential manual follow-up or future implementation.
has_pending_email_users = bool(waitlist.unaffiliatedEmailUsers)
if has_pending_email_users:
logger.info(
f"Waitlist {waitlist.id} has {len(waitlist.unaffiliatedEmailUsers)} "
f"unaffiliated email users that need email notifications"
)
# Only mark waitlist as DONE if all registered user notifications succeeded
# AND there are no unaffiliated email users still waiting for notifications
if not failed_user_ids and not has_pending_email_users:
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist.id},
data={"status": prisma.enums.WaitlistExternalStatus.DONE},
)
logger.info(f"Updated waitlist {waitlist.id} status to DONE")
elif failed_user_ids:
logger.warning(
f"Waitlist {waitlist.id} not marked as DONE due to "
f"{len(failed_user_ids)} failed notifications"
)
elif has_pending_email_users:
logger.warning(
f"Waitlist {waitlist.id} not marked as DONE due to "
f"{len(waitlist.unaffiliatedEmailUsers)} pending email-only users"
)
logger.info(
f"Sent {notification_count} waitlist launch notifications for store listing {store_listing_id}"
)
return notification_count
except Exception as e:
logger.error(
f"Error notifying waitlist users for store listing {store_listing_id}: {e}"
)
# Don't raise - we don't want to fail the approval process
return 0
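A hedged sketch of how the pending email-only signups could eventually be notified, per the note above; send_launch_email is a hypothetical helper standing in for whichever email service gets wired up later.
# --- illustrative sketch, not part of this diff ---
async def _notify_email_only_signups(waitlist, agent_name: str, store_url: str) -> None:
    # Hypothetical follow-up: addresses in unaffiliatedEmailUsers have no user ID,
    # so they bypass queue_notification_async and would go straight to an email sender.
    for address in waitlist.unaffiliatedEmailUsers or []:
        await send_launch_email(  # hypothetical helper, not part of the codebase
            to=address,
            subject=f"{agent_name} has launched",
            body=f"The agent you waitlisted is live: {store_url}",
        )
# --- end sketch ---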


@@ -21,6 +21,7 @@ from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above # Embedding dimension for the model above
@@ -62,42 +63,49 @@ def build_searchable_text(
return " ".join(parts) return " ".join(parts)
async def generate_embedding(text: str) -> list[float]: async def generate_embedding(text: str) -> list[float] | None:
""" """
Generate embedding for text using OpenAI API. Generate embedding for text using OpenAI API.
Raises exceptions on failure - caller should handle. Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
""" """
client = get_openai_client() try:
if not client: client = get_openai_client()
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding") if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken # Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type # Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL) enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text) tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS: if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS] tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens) truncated_text = enc.decode(tokens)
logger.info( logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens" f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
) )
else: latency_ms = (time.time() - start_time) * 1000
truncated_text = text
start_time = time.time() embedding = response.data[0].embedding
response = await client.embeddings.create( logger.info(
model=EMBEDDING_MODEL, f"Generated embedding: {len(embedding)} dims, "
input=truncated_text, f"{len(tokens)} tokens, {latency_ms:.0f}ms"
) )
latency_ms = (time.time() - start_time) * 1000 return embedding
embedding = response.data[0].embedding except Exception as e:
logger.info( logger.error(f"Failed to generate embedding: {e}")
f"Generated embedding: {len(embedding)} dims, " return None
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
async def store_embedding( async def store_embedding(
@@ -136,45 +144,48 @@ async def store_content_embedding(
    New function for unified content embedding storage.
    Uses raw SQL since Prisma doesn't natively support pgvector.
    """
    try:
        client = tx if tx else prisma.get_client()

        # Convert embedding to PostgreSQL vector format
        embedding_str = embedding_to_vector_string(embedding)
        metadata_json = dumps(metadata or {})

        # Upsert the embedding
        # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
        # Use {pgvector_schema}.vector for explicit pgvector type qualification
        await execute_raw_with_schema(
            """
            INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
                "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
            )
            VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::{pgvector_schema}.vector, $5, $6::jsonb, NOW(), NOW())
            ON CONFLICT ("contentType", "contentId", "userId")
            DO UPDATE SET
                "embedding" = $4::{pgvector_schema}.vector,
                "searchableText" = $5,
                "metadata" = $6::jsonb,
                "updatedAt" = NOW()
            WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
              AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
              AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
            """,
            content_type,
            content_id,
            user_id,
            embedding_str,
            searchable_text,
            metadata_json,
            client=client,
        )

        logger.info(f"Stored embedding for {content_type}:{content_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
        return False


async def get_embedding(version_id: str) -> dict[str, Any] | None:
@@ -206,31 +217,34 @@ async def get_content_embedding(
    New function for unified content embedding retrieval.
    Returns dict with contentType, contentId, embedding, timestamps or None if not found.
    """
    try:
        result = await query_raw_with_schema(
            """
            SELECT
                "contentType",
                "contentId",
                "userId",
                "embedding"::text as "embedding",
                "searchableText",
                "metadata",
                "createdAt",
                "updatedAt"
            FROM {schema_prefix}"UnifiedContentEmbedding"
            WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
            """,
            content_type,
            content_id,
            user_id,
        )

        if result and len(result) > 0:
            return result[0]
        return None
    except Exception as e:
        logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
        return None


async def ensure_embedding(
@@ -258,38 +272,46 @@ async def ensure_embedding(
        tx: Optional transaction client

    Returns:
        True if embedding exists/was created, False on failure
    """
    try:
        # Check if embedding already exists
        if not force:
            existing = await get_embedding(version_id)
            if existing and existing.get("embedding"):
                logger.debug(f"Embedding for version {version_id} already exists")
                return True

        # Build searchable text for embedding
        searchable_text = build_searchable_text(
            name, description, sub_heading, categories
        )

        # Generate new embedding
        embedding = await generate_embedding(searchable_text)
        if embedding is None:
            logger.warning(f"Could not generate embedding for version {version_id}")
            return False

        # Store the embedding with metadata using new function
        metadata = {
            "name": name,
            "subHeading": sub_heading,
            "categories": categories,
        }
        return await store_content_embedding(
            content_type=ContentType.STORE_AGENT,
            content_id=version_id,
            embedding=embedding,
            searchable_text=searchable_text,
            metadata=metadata,
            user_id=None,  # Store agents are public
            tx=tx,
        )
    except Exception as e:
        logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
        return False


async def delete_embedding(version_id: str) -> bool:
@@ -499,24 +521,6 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate unique errors to avoid Sentry spam
if failed > 0:
# Group errors by type and message
error_summary: dict[str, int] = {}
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
error_summary[error_key] = error_summary.get(error_key, 0) + 1
# Log aggregated error summary
error_details = ", ".join(
f"{error} ({count}x)" for error, count in error_summary.items()
)
logger.error(
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
f"Errors: {error_details}"
)
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
@@ -553,12 +557,11 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
}


async def embed_query(query: str) -> list[float] | None:
    """
    Generate embedding for a search query.
    Same as generate_embedding but with clearer intent.
    """
    return await generate_embedding(query)
@@ -591,30 +594,40 @@ async def ensure_content_embedding(
        tx: Optional transaction client

    Returns:
        True if embedding exists/was created, False on failure
    """
    try:
        # Check if embedding already exists
        if not force:
            existing = await get_content_embedding(content_type, content_id, user_id)
            if existing and existing.get("embedding"):
                logger.debug(
                    f"Embedding for {content_type}:{content_id} already exists"
                )
                return True

        # Generate new embedding
        embedding = await generate_embedding(searchable_text)
        if embedding is None:
            logger.warning(
                f"Could not generate embedding for {content_type}:{content_id}"
            )
            return False

        # Store the embedding
        return await store_content_embedding(
            content_type=content_type,
            content_id=content_id,
            embedding=embedding,
            searchable_text=searchable_text,
            metadata=metadata or {},
            user_id=user_id,
            tx=tx,
        )
    except Exception as e:
        logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
        return False


async def cleanup_orphaned_embeddings() -> dict[str, Any]:
@@ -841,8 +854,9 @@ async def semantic_search(
    limit = 100

    # Generate query embedding
    query_embedding = await embed_query(query)

    if query_embedding is not None:
        # Semantic search with embeddings
        embedding_str = embedding_to_vector_string(query_embedding)
@@ -865,7 +879,8 @@ async def semantic_search(
        min_similarity_idx = len(params) + 1
        params.append(min_similarity)

        # Use regular string (not f-string) for template to preserve {schema_prefix} and {schema} placeholders
        # Use OPERATOR({pgvector_schema}.<=>) for explicit operator schema qualification
        sql = (
            """
            SELECT
@@ -873,9 +888,9 @@ async def semantic_search(
"contentType" as content_type, "contentType" as content_type,
"searchableText" as searchable_text, "searchableText" as searchable_text,
metadata, metadata,
1 - (embedding <=> '""" 1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
+ embedding_str + embedding_str
+ """'::vector) as similarity + """'::{pgvector_schema}.vector) as similarity
FROM {schema_prefix}"UnifiedContentEmbedding" FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" IN (""" WHERE "contentType" IN ("""
+ content_type_placeholders + content_type_placeholders
@@ -883,9 +898,9 @@ async def semantic_search(
""" """
+ user_filter + user_filter
+ """ + """
AND 1 - (embedding <=> '""" AND 1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
+ embedding_str + embedding_str
+ """'::vector) >= $""" + """'::{pgvector_schema}.vector) >= $"""
+ str(min_similarity_idx) + str(min_similarity_idx)
+ """ + """
ORDER BY similarity DESC ORDER BY similarity DESC
@@ -893,21 +908,24 @@ async def semantic_search(
""" """
) )
results = await query_raw_with_schema(sql, *params) try:
return [ results = await query_raw_with_schema(sql, *params)
{ return [
"content_id": row["content_id"], {
"content_type": row["content_type"], "content_id": row["content_id"],
"searchable_text": row["searchable_text"], "content_type": row["content_type"],
"metadata": row["metadata"], "searchable_text": row["searchable_text"],
"similarity": float(row["similarity"]), "metadata": row["metadata"],
} "similarity": float(row["similarity"]),
for row in results }
] for row in results
except Exception as e: ]
logger.warning(f"Semantic search failed, falling back to lexical search: {e}") except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
# Fallback to lexical search if embeddings unavailable # Fallback to lexical search if embeddings unavailable
logger.warning("Falling back to lexical search (embeddings unavailable)")
params_lexical: list[Any] = [limit] params_lexical: list[Any] = [limit]
user_filter = "" user_filter = ""

View File

@@ -298,16 +298,17 @@ async def test_schema_handling_error_cases():
mock_client.execute_raw.side_effect = Exception("Database error") mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client mock_get_client.return_value = mock_client
# Should raise exception on error result = await embeddings.store_content_embedding(
with pytest.raises(Exception, match="Database error"): content_type=ContentType.STORE_AGENT,
await embeddings.store_content_embedding( content_id="test-id",
content_type=ContentType.STORE_AGENT, embedding=[0.1] * EMBEDDING_DIM,
content_id="test-id", searchable_text="test",
embedding=[0.1] * EMBEDDING_DIM, metadata=None,
searchable_text="test", user_id=None,
metadata=None, )
user_id=None,
) # Should return False on error, not raise
assert result is False
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -80,8 +80,9 @@ async def test_generate_embedding_no_api_key():
    ) as mock_get_client:
        mock_get_client.return_value = None

        result = await embeddings.generate_embedding("test text")

        assert result is None


@pytest.mark.asyncio(loop_scope="session")
@@ -96,8 +97,9 @@ async def test_generate_embedding_api_error():
    ) as mock_get_client:
        mock_get_client.return_value = mock_client

        result = await embeddings.generate_embedding("test text")

        assert result is None


@pytest.mark.asyncio(loop_scope="session")
@@ -171,10 +173,11 @@ async def test_store_embedding_database_error(mocker):
    embedding = [0.1, 0.2, 0.3]

    result = await embeddings.store_embedding(
        version_id="test-version-id", embedding=embedding, tx=mock_client
    )

    assert result is False


@pytest.mark.asyncio(loop_scope="session")
@@ -274,16 +277,17 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
    """Test ensure_embedding when generation fails."""
    mock_get.return_value = None
    mock_generate.return_value = None

    result = await embeddings.ensure_embedding(
        version_id="test-id",
        name="Test",
        description="Test description",
        sub_heading="Test heading",
        categories=["test"],
    )

    assert result is False


@pytest.mark.asyncio(loop_scope="session")

View File

@@ -186,12 +186,13 @@ async def unified_hybrid_search(
    offset = (page - 1) * page_size

    # Generate query embedding
    query_embedding = await embed_query(query)

    # Graceful degradation if embedding unavailable
    if query_embedding is None or not query_embedding:
        logger.warning(
            "Failed to generate query embedding - falling back to lexical-only search. "
            "Check that openai_internal_api_key is configured and OpenAI API is accessible."
        )
        query_embedding = [0.0] * EMBEDDING_DIM
@@ -294,7 +295,7 @@ async def unified_hybrid_search(
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[]) WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
{user_filter} {user_filter}
ORDER BY uce.embedding <=> {embedding_param}::vector ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
LIMIT 200 LIMIT 200
) )
), ),
@@ -306,7 +307,7 @@ async def unified_hybrid_search(
            uce.metadata,
            uce."updatedAt" as updated_at,
            -- Semantic score: cosine similarity (1 - distance)
            COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
            -- Lexical score: ts_rank_cd
            COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match from metadata
@@ -463,12 +464,13 @@ async def hybrid_search(
    offset = (page - 1) * page_size

    # Generate query embedding
    query_embedding = await embed_query(query)

    # Graceful degradation
    if query_embedding is None or not query_embedding:
        logger.warning(
            "Failed to generate query embedding - falling back to lexical-only search."
        )
        query_embedding = [0.0] * EMBEDDING_DIM

    total_non_semantic = (
@@ -581,7 +583,7 @@ async def hybrid_search(
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType" WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL AND uce."userId" IS NULL
AND {where_clause} AND {where_clause}
ORDER BY uce.embedding <=> {embedding_param}::vector ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
LIMIT 200 LIMIT 200
) uce ) uce
), ),
@@ -603,7 +605,7 @@ async def hybrid_search(
            -- Searchable text for BM25 reranking
            COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
            -- Semantic score
            COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
            -- Lexical score (raw, will normalize)
            COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
            -- Category match

View File

@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query:
        # Simulate embedding failure
        mock_embed.return_value = None
        mock_query.return_value = mock_results

        # Should NOT raise - graceful degradation
@@ -613,9 +613,7 @@ async def test_unified_hybrid_search_graceful_degradation():
"backend.api.features.store.hybrid_search.embed_query" "backend.api.features.store.hybrid_search.embed_query"
) as mock_embed: ) as mock_embed:
mock_query.return_value = mock_results mock_query.return_value = mock_results
mock_embed.side_effect = Exception( mock_embed.return_value = None # Embedding failure
"Embedding generation failed"
) # Embedding failure
# Should NOT raise - graceful degradation # Should NOT raise - graceful degradation
results, total = await unified_hybrid_search( results, total = await unified_hybrid_search(

View File

@@ -223,6 +223,102 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
internal_comments: str | None = None  # Private admin notes
class StoreWaitlistEntry(pydantic.BaseModel):
"""Public waitlist entry - no PII fields exposed."""
waitlistId: str
slug: str
# Content fields
name: str
subHeading: str
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
imageUrls: list[str]
description: str
categories: list[str]
class StoreWaitlistsAllResponse(pydantic.BaseModel):
listings: list[StoreWaitlistEntry]
# Admin Waitlist Models
class WaitlistCreateRequest(pydantic.BaseModel):
"""Request model for creating a new waitlist."""
name: str
slug: str
subHeading: str
description: str
categories: list[str] = []
imageUrls: list[str] = []
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
class WaitlistUpdateRequest(pydantic.BaseModel):
"""Request model for updating a waitlist."""
name: str | None = None
slug: str | None = None
subHeading: str | None = None
description: str | None = None
categories: list[str] | None = None
imageUrls: list[str] | None = None
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
status: prisma.enums.WaitlistExternalStatus | None = None
storeListingId: str | None = None # Link to a store listing
class WaitlistAdminResponse(pydantic.BaseModel):
"""Admin response model with full waitlist details including internal data."""
id: str
createdAt: str
updatedAt: str
slug: str
name: str
subHeading: str
description: str
categories: list[str]
imageUrls: list[str]
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
status: prisma.enums.WaitlistExternalStatus
votes: int
signupCount: int # Total count of joinedUsers + unaffiliatedEmailUsers
storeListingId: str | None = None
owningUserId: str
class WaitlistSignup(pydantic.BaseModel):
"""Individual signup entry for a waitlist."""
type: str # "user" or "email"
userId: str | None = None
email: str | None = None
username: str | None = None # For user signups
class WaitlistSignupListResponse(pydantic.BaseModel):
"""Response model for listing waitlist signups."""
waitlistId: str
signups: list[WaitlistSignup]
totalCount: int
class WaitlistAdminListResponse(pydantic.BaseModel):
"""Response model for listing all waitlists (admin view)."""
waitlists: list[WaitlistAdminResponse]
totalCount: int
class UnifiedSearchResult(pydantic.BaseModel):
"""A single result from unified hybrid search across all content types."""

View File

@@ -8,6 +8,7 @@ import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from autogpt_libs.auth.dependencies import get_optional_user_id

import backend.data.graph
import backend.util.json
@@ -81,6 +82,74 @@ async def update_or_create_profile(
return updated_profile
##############################################
############## Waitlist Endpoints ############
##############################################
@router.get(
"/waitlist",
summary="Get the agent waitlist",
tags=["store", "public"],
response_model=store_model.StoreWaitlistsAllResponse,
)
async def get_waitlist():
"""
Get all active waitlists for public display.
"""
waitlists = await store_db.get_waitlist()
return store_model.StoreWaitlistsAllResponse(listings=waitlists)
@router.get(
"/waitlist/my-memberships",
summary="Get waitlist IDs the current user has joined",
tags=["store", "private"],
)
async def get_my_waitlist_memberships(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> list[str]:
"""Returns list of waitlist IDs the authenticated user has joined."""
return await store_db.get_user_waitlist_memberships(user_id)
@router.post(
path="/waitlist/{waitlist_id}/join",
summary="Add self to the agent waitlist",
tags=["store", "public"],
response_model=store_model.StoreWaitlistEntry,
)
async def add_self_to_waitlist(
user_id: str | None = fastapi.Security(get_optional_user_id),
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist to join"),
email: str | None = fastapi.Body(
default=None, embed=True, description="Email address for unauthenticated users"
),
):
"""
Add the current user to the agent waitlist.
"""
if not user_id and not email:
raise fastapi.HTTPException(
status_code=400,
detail="Either user authentication or email address is required",
)
try:
waitlist_entry = await store_db.add_user_to_waitlist(
waitlist_id=waitlist_id, user_id=user_id, email=email
)
return waitlist_entry
except ValueError as e:
error_msg = str(e)
if "not found" in error_msg:
raise fastapi.HTTPException(status_code=404, detail="Waitlist not found")
# Waitlist exists but is closed or unavailable
raise fastapi.HTTPException(status_code=400, detail=error_msg)
except Exception:
raise fastapi.HTTPException(
status_code=500, detail="An error occurred while joining the waitlist"
)
##############################################
############### Agent Endpoints ##############
##############################################
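A rough client-side sketch of the two ways the join route above can be called, either authenticated or anonymously with an email body; the base URL, waitlist ID, and bearer token are placeholders, and the /api/store prefix is an assumption about how this router is mounted.

import httpx

BASE = "https://api.example.com/api/store"  # placeholder

# Authenticated user: identity comes from the bearer token, no body needed.
httpx.post(f"{BASE}/waitlist/<waitlist-id>/join",
           headers={"Authorization": "Bearer <token>"})

# Anonymous signup: supply an email in the JSON body instead (embed=True -> {"email": ...}).
httpx.post(f"{BASE}/waitlist/<waitlist-id>/join", json={"email": "fan@example.com"})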

View File

@@ -364,8 +364,6 @@ async def execute_graph_block(
    obj = get_block(block_id)
    if not obj:
        raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
    if obj.disabled:
        raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")

    user = await get_user_by_id(user_id)
    if not user:
@@ -763,8 +761,10 @@ async def create_new_graph(
    graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
    graph.validate_graph(for_run=False)

    # The return value of the create graph & library function is intentionally not used here,
    # as the graph is already valid and no sub-graphs are returned.
    await graph_db.create_graph(graph, user_id=user_id)
    await library_db.create_library_agent(graph, user_id=user_id)

    activated_graph = await on_graph_activate(graph, user_id=user_id)

    if create_graph.source == "builder":
@@ -888,19 +888,21 @@ async def set_graph_active_version(
async def _update_library_agent_version_and_settings(
    user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
    # Keep the library agent up to date with the new active version
    library = await library_db.update_agent_version_in_library(
        user_id, agent_graph.id, agent_graph.version
    )

    # If the graph has a HITL node, initialize the setting if it's not already set.
    if (
        agent_graph.has_human_in_the_loop
        and library.settings.human_in_the_loop_safe_mode is None
    ):
        await library_db.update_library_agent_settings(
            user_id=user_id,
            agent_id=library.id,
            settings=library.settings.model_copy(
                update={"human_in_the_loop_safe_mode": True}
            ),
        )
    return library
@@ -917,18 +919,21 @@ async def update_graph_settings(
    user_id: Annotated[str, Security(get_user_id)],
) -> GraphSettings:
    """Update graph settings for the user's library agent."""
    # Get the library agent for this graph
    library_agent = await library_db.get_library_agent_by_graph_id(
        graph_id=graph_id, user_id=user_id
    )
    if not library_agent:
        raise HTTPException(404, f"Graph #{graph_id} not found in user's library")

    # Update the library agent settings
    updated_agent = await library_db.update_library_agent_settings(
        user_id=user_id,
        agent_id=library_agent.id,
        settings=settings,
    )

    # Return the updated settings
    return GraphSettings.model_validate(updated_agent.settings)

View File

@@ -138,7 +138,6 @@ def test_execute_graph_block(
"""Test execute block endpoint""" """Test execute block endpoint"""
# Mock block # Mock block
mock_block = Mock() mock_block = Mock()
mock_block.disabled = False
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
yield "output1", {"data": "result1"} yield "output1", {"data": "result1"}

View File

@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.waitlist_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -283,6 +284,11 @@ app.include_router(
tags=["v2", "admin"], tags=["v2", "admin"],
prefix="/api/store", prefix="/api/store",
) )
app.include_router(
backend.api.features.admin.waitlist_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router( app.include_router(
backend.api.features.admin.credit_admin_routes.router, backend.api.features.admin.credit_admin_routes.router,
tags=["v2", "admin"], tags=["v2", "admin"],

View File

@@ -116,7 +116,6 @@ class PrintToConsoleBlock(Block):
            input_schema=PrintToConsoleBlock.Input,
            output_schema=PrintToConsoleBlock.Output,
            test_input={"text": "Hello, World!"},
            is_sensitive_action=True,
            test_output=[
                ("output", "Hello, World!"),
                ("status", "printed"),

View File

@@ -1,659 +0,0 @@
import json
import shlex
import uuid
from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class ClaudeCodeExecutionError(Exception):
"""Exception raised when Claude Code execution fails.
Carries the sandbox_id so it can be returned to the user for cleanup
when dispose_sandbox=False.
"""
def __init__(self, message: str, sandbox_id: str = ""):
super().__init__(message)
self.sandbox_id = sandbox_id
# Test credentials for E2B
TEST_E2B_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_E2B_CREDENTIALS_INPUT = {
"provider": TEST_E2B_CREDENTIALS.provider,
"id": TEST_E2B_CREDENTIALS.id,
"type": TEST_E2B_CREDENTIALS.type,
"title": TEST_E2B_CREDENTIALS.title,
}
# Test credentials for Anthropic
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
provider="anthropic",
api_key=SecretStr("mock-anthropic-api-key"),
title="Mock Anthropic API key",
expires_at=None,
)
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
"id": TEST_ANTHROPIC_CREDENTIALS.id,
"type": TEST_ANTHROPIC_CREDENTIALS.type,
"title": TEST_ANTHROPIC_CREDENTIALS.title,
}
class ClaudeCodeBlock(Block):
"""
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
Claude Code can create files, install tools, run commands, and perform complex
coding tasks autonomously within a secure sandbox environment.
"""
# Use base template - we'll install Claude Code ourselves for latest version
DEFAULT_TEMPLATE = "base"
class Input(BlockSchemaInput):
e2b_credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"API key for the E2B platform to create the sandbox. "
"Get one on the [e2b website](https://e2b.dev/docs)"
),
)
anthropic_credentials: CredentialsMetaInput[
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
] = CredentialsField(
description=(
"API key for Anthropic to power Claude Code. "
"Get one at [Anthropic's website](https://console.anthropic.com)"
),
)
prompt: str = SchemaField(
description=(
"The task or instruction for Claude Code to execute. "
"Claude Code can create files, install packages, run commands, "
"and perform complex coding tasks."
),
placeholder="Create a hello world index.html file",
default="",
advanced=False,
)
timeout: int = SchemaField(
description=(
"Sandbox timeout in seconds. Claude Code tasks can take "
"a while, so set this appropriately for your task complexity. "
"Note: This only applies when creating a new sandbox. "
"When reconnecting to an existing sandbox via sandbox_id, "
"the original timeout is retained."
),
default=300, # 5 minutes default
advanced=True,
)
setup_commands: list[str] = SchemaField(
description=(
"Optional shell commands to run before executing Claude Code. "
"Useful for installing dependencies or setting up the environment."
),
default_factory=list,
advanced=True,
)
working_directory: str = SchemaField(
description="Working directory for Claude Code to operate in.",
default="/home/user",
advanced=True,
)
# Session/continuation support
session_id: str = SchemaField(
description=(
"Session ID to resume a previous conversation. "
"Leave empty for a new conversation. "
"Use the session_id from a previous run to continue that conversation."
),
default="",
advanced=True,
)
sandbox_id: str = SchemaField(
description=(
"Sandbox ID to reconnect to an existing sandbox. "
"Required when resuming a session (along with session_id). "
"Use the sandbox_id from a previous run where dispose_sandbox was False."
),
default="",
advanced=True,
)
conversation_history: str = SchemaField(
description=(
"Previous conversation history to continue from. "
"Use this to restore context on a fresh sandbox if the previous one timed out. "
"Pass the conversation_history output from a previous run."
),
default="",
advanced=True,
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"Set to False if you want to continue the conversation later "
"(you'll need both sandbox_id and session_id from the output)."
),
default=True,
advanced=True,
)
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput):
response: str = SchemaField(
description="The output/response from Claude Code execution"
)
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=(
"List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
)
)
conversation_history: str = SchemaField(
description=(
"Full conversation history including this turn. "
"Pass this to conversation_history input to continue on a fresh sandbox "
"if the previous sandbox timed out."
)
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back along with sandbox_id to continue the conversation."
)
)
sandbox_id: Optional[str] = SchemaField(
description=(
"ID of the sandbox instance. "
"Pass this back along with session_id to continue the conversation. "
"This is None if dispose_sandbox was True (sandbox was disposed)."
),
default=None,
)
error: str = SchemaField(description="Error message if execution failed")
def __init__(self):
super().__init__(
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
description=(
"Execute tasks using Claude Code in an E2B sandbox. "
"Claude Code can create files, install tools, run commands, "
"and perform complex coding tasks autonomously."
),
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
input_schema=ClaudeCodeBlock.Input,
output_schema=ClaudeCodeBlock.Output,
test_credentials={
"e2b_credentials": TEST_E2B_CREDENTIALS,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
},
test_input={
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
"prompt": "Create a hello world HTML file",
"timeout": 300,
"setup_commands": [],
"working_directory": "/home/user",
"session_id": "",
"sandbox_id": "",
"conversation_history": "",
"dispose_sandbox": True,
},
test_output=[
("response", "Created index.html with hello world content"),
(
"files",
[
{
"path": "/home/user/index.html",
"relative_path": "index.html",
"name": "index.html",
"content": "<html>Hello World</html>",
}
],
),
(
"conversation_history",
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content",
),
("session_id", str),
("sandbox_id", None), # None because dispose_sandbox=True in test_input
],
test_mock={
"execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response
[
ClaudeCodeBlock.FileOutput(
path="/home/user/index.html",
relative_path="index.html",
name="index.html",
content="<html>Hello World</html>",
)
], # files
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content", # conversation_history
"test-session-id", # session_id
"sandbox_id", # sandbox_id
),
},
)
async def execute_claude_code(
self,
e2b_api_key: str,
anthropic_api_key: str,
prompt: str,
timeout: int,
setup_commands: list[str],
working_directory: str,
session_id: str,
existing_sandbox_id: str,
conversation_history: str,
dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
"""
Execute Claude Code in an E2B sandbox.
Returns:
Tuple of (response, files, conversation_history, session_id, sandbox_id)
"""
# Validate that sandbox_id is provided when resuming a session
if session_id and not existing_sandbox_id:
raise ValueError(
"sandbox_id is required when resuming a session with session_id. "
"The session state is stored in the original sandbox. "
"If the sandbox has timed out, use conversation_history instead "
"to restore context on a fresh sandbox."
)
sandbox = None
sandbox_id = ""
try:
# Either reconnect to existing sandbox or create a new one
if existing_sandbox_id:
# Reconnect to existing sandbox for conversation continuation
sandbox = await BaseAsyncSandbox.connect(
sandbox_id=existing_sandbox_id,
api_key=e2b_api_key,
)
else:
# Create new sandbox
sandbox = await BaseAsyncSandbox.create(
template=self.DEFAULT_TEMPLATE,
api_key=e2b_api_key,
timeout=timeout,
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
)
# Install Claude Code from npm (ensures we get the latest version)
install_result = await sandbox.commands.run(
"npm install -g @anthropic-ai/claude-code@latest",
timeout=120, # 2 min timeout for install
)
if install_result.exit_code != 0:
raise Exception(
f"Failed to install Claude Code: {install_result.stderr}"
)
# Run any user-provided setup commands
for cmd in setup_commands:
setup_result = await sandbox.commands.run(cmd)
if setup_result.exit_code != 0:
raise Exception(
f"Setup command failed: {cmd}\n"
f"Exit code: {setup_result.exit_code}\n"
f"Stdout: {setup_result.stdout}\n"
f"Stderr: {setup_result.stderr}"
)
# Capture sandbox_id immediately after creation/connection
# so it's available for error recovery if dispose_sandbox=False
sandbox_id = sandbox.sandbox_id
# Generate or use provided session ID
current_session_id = session_id if session_id else str(uuid.uuid4())
# Build base Claude flags
base_flags = "-p --dangerously-skip-permissions --output-format json"
# Add conversation history context if provided (for fresh sandbox continuation)
history_flag = ""
if conversation_history and not session_id:
# Inject previous conversation as context via system prompt
# Use consistent escaping via _escape_prompt helper
escaped_history = self._escape_prompt(
f"Previous conversation context: {conversation_history}"
)
history_flag = f" --append-system-prompt {escaped_history}"
# Build Claude command based on whether we're resuming or starting new
# Use shlex.quote for working_directory and session IDs to prevent injection
safe_working_dir = shlex.quote(working_directory)
if session_id:
# Resuming existing session (sandbox still alive)
safe_session_id = shlex.quote(session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --resume {safe_session_id} {base_flags}"
)
else:
# New session with specific ID
safe_current_session_id = shlex.quote(current_session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
)
# Capture timestamp before running Claude Code to filter files later
# Capture timestamp 1 second in the past to avoid race condition with file creation
timestamp_result = await sandbox.commands.run(
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
)
if timestamp_result.exit_code != 0:
raise RuntimeError(
f"Failed to capture timestamp: {timestamp_result.stderr}"
)
start_timestamp = (
timestamp_result.stdout.strip() if timestamp_result.stdout else None
)
result = await sandbox.commands.run(
claude_command,
timeout=0, # No command timeout - let sandbox timeout handle it
)
# Check for command failure
if result.exit_code != 0:
error_msg = result.stderr or result.stdout or "Unknown error"
raise Exception(
f"Claude Code command failed with exit code {result.exit_code}:\n"
f"{error_msg}"
)
raw_output = result.stdout or ""
# Parse JSON output to extract response and build conversation history
response = ""
new_conversation_history = conversation_history or ""
try:
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
except json.JSONDecodeError:
# If not valid JSON, use raw output
response = raw_output
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
# Extract files created/modified during this run
files = await self._extract_files(
sandbox, working_directory, start_timestamp
)
return (
response,
files,
new_conversation_history,
current_session_id,
sandbox_id,
)
except Exception as e:
# Wrap exception with sandbox_id so caller can access/cleanup
# the preserved sandbox when dispose_sandbox=False
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
finally:
if dispose_sandbox and sandbox:
await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
async def run(
self,
input_data: Input,
*,
e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
try:
(
response,
files,
conversation_history,
session_id,
sandbox_id,
) = await self.execute_claude_code(
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
prompt=input_data.prompt,
timeout=input_data.timeout,
setup_commands=input_data.setup_commands,
working_directory=input_data.working_directory,
session_id=input_data.session_id,
existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox,
)
yield "response", response
# Always yield files (empty list if none) to match Output schema
yield "files", [f.model_dump() for f in files]
# Always yield conversation_history so user can restore context on fresh sandbox
yield "conversation_history", conversation_history
# Always yield session_id so user can continue conversation
yield "session_id", session_id
# Always yield sandbox_id (None if disposed) to match Output schema
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
except ClaudeCodeExecutionError as e:
yield "error", str(e)
# If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
# so user can reconnect to or clean up the orphaned sandbox
if not input_data.dispose_sandbox and e.sandbox_id:
yield "sandbox_id", e.sandbox_id
except Exception as e:
yield "error", str(e)

View File

@@ -1,68 +0,0 @@
"""Text encoding block for converting special characters to escape sequences."""
import codecs
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
class TextEncoderBlock(Block):
"""
Encodes a string by converting special characters into escape sequences.
This block is the inverse of TextDecoderBlock. It takes text containing
special characters (like newlines, tabs, etc.) and converts them into
their escape sequence representations (e.g., newline becomes \\n).
"""
class Input(BlockSchemaInput):
"""Input schema for TextEncoderBlock."""
text: str = SchemaField(
description="A string containing special characters to be encoded",
placeholder="Your text with newlines and quotes to encode",
)
class Output(BlockSchemaOutput):
"""Output schema for TextEncoderBlock."""
encoded_text: str = SchemaField(
description="The encoded text with special characters converted to escape sequences"
)
def __init__(self):
super().__init__(
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
description="Encodes a string by converting special characters into escape sequences",
categories={BlockCategory.TEXT},
input_schema=TextEncoderBlock.Input,
output_schema=TextEncoderBlock.Output,
test_input={"text": """Hello
World!
This is a "quoted" string."""},
test_output=[
(
"encoded_text",
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
)
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Encode the input text by converting special characters to escape sequences.
Args:
input_data: The input containing the text to encode.
**kwargs: Additional keyword arguments (unused).
Yields:
The encoded text with escape sequences.
"""
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode("utf-8")
yield "encoded_text", encoded_text

View File

@@ -9,7 +9,7 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel

from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
@@ -28,11 +28,6 @@ class ReviewDecision(BaseModel):
class HITLReviewHelper: class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations.""" """Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def check_approval(**kwargs) -> Optional[ReviewResult]:
"""Check if there's an existing approval for this node execution."""
return await get_database_manager_async_client().check_approval(**kwargs)
    @staticmethod
    async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
        """Create or retrieve a human review from the database."""
@@ -60,11 +55,11 @@ class HITLReviewHelper:
    async def _handle_review_request(
        input_data: Any,
        user_id: str,
        node_exec_id: str,
        graph_exec_id: str,
        graph_id: str,
        graph_version: int,
        execution_context: ExecutionContext,
        block_name: str = "Block",
        editable: bool = False,
    ) -> Optional[ReviewResult]:
@@ -74,11 +69,11 @@ class HITLReviewHelper:
        Args:
            input_data: The input data to be reviewed
            user_id: ID of the user requesting the review
            node_exec_id: ID of the node execution
            graph_exec_id: ID of the graph execution
            graph_id: ID of the graph
            graph_version: Version of the graph
            execution_context: Current execution context
            block_name: Name of the block requesting review
            editable: Whether the reviewer can edit the data
@@ -88,41 +83,15 @@ class HITLReviewHelper:
        Raises:
            Exception: If review creation or status update fails
        """
        # Skip review if safe mode is disabled - return auto-approved result
        if not execution_context.safe_mode:
            logger.info(
                f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
            )
            return ReviewResult(
                data=input_data,
                status=ReviewStatus.APPROVED,
                message="Auto-approved (safe mode disabled)",
                processed=True,
                node_exec_id=node_exec_id,
            )
@@ -134,7 +103,7 @@ class HITLReviewHelper:
graph_id=graph_id, graph_id=graph_id,
graph_version=graph_version, graph_version=graph_version,
input_data=input_data, input_data=input_data,
message=block_name, # Use block_name directly as the message
message=f"Review required for {block_name} execution",
editable=editable, editable=editable,
) )
@@ -160,11 +129,11 @@ class HITLReviewHelper:
async def handle_review_decision( async def handle_review_decision(
input_data: Any, input_data: Any,
user_id: str, user_id: str,
node_id: str,
node_exec_id: str, node_exec_id: str,
graph_exec_id: str, graph_exec_id: str,
graph_id: str, graph_id: str,
graph_version: int, graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block", block_name: str = "Block",
editable: bool = False, editable: bool = False,
) -> Optional[ReviewDecision]: ) -> Optional[ReviewDecision]:
@@ -174,11 +143,11 @@ class HITLReviewHelper:
Args: Args:
input_data: The input data to be reviewed input_data: The input data to be reviewed
user_id: ID of the user requesting the review user_id: ID of the user requesting the review
node_id: ID of the node in the graph definition
node_exec_id: ID of the node execution node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution graph_exec_id: ID of the graph execution
graph_id: ID of the graph graph_id: ID of the graph
graph_version: Version of the graph graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data editable: Whether the reviewer can edit the data
@@ -189,11 +158,11 @@ class HITLReviewHelper:
review_result = await HITLReviewHelper._handle_review_request( review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data, input_data=input_data,
user_id=user_id, user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id, node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id, graph_exec_id=graph_exec_id,
graph_id=graph_id, graph_id=graph_id,
graph_version=graph_version, graph_version=graph_version,
execution_context=execution_context,
block_name=block_name, block_name=block_name,
editable=editable, editable=editable,
) )
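
Net effect of the hunks above: the per-node approval lookup (check_approval and the node_id parameter) is gone, and the helper now gates on a single execution_context.safe_mode flag, auto-approving when it is off and otherwise creating or fetching a pending review. A minimal sketch of calling the updated helper with the signature shown in this diff; the IDs and payload below are placeholder values, not taken from the codebase:

from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.execution import ExecutionContext

async def request_review_example() -> None:
    # Placeholder identifiers; in a real block these come from the executor.
    decision = await HITLReviewHelper.handle_review_decision(
        input_data={"action": "send_email", "to": "user@example.com"},
        user_id="user-123",
        node_exec_id="node-exec-456",
        graph_exec_id="graph-exec-789",
        graph_id="graph-abc",
        graph_version=1,
        execution_context=ExecutionContext(safe_mode=True),
        block_name="Send Email",
        editable=True,
    )
    if decision is None:
        # No decision yet: a review is pending and execution should pause.
        return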

View File

@@ -97,7 +97,6 @@ class HumanInTheLoopBlock(Block):
input_data: Input, input_data: Input,
*, *,
user_id: str, user_id: str,
node_id: str,
node_exec_id: str, node_exec_id: str,
graph_exec_id: str, graph_exec_id: str,
graph_id: str, graph_id: str,
@@ -105,7 +104,7 @@ class HumanInTheLoopBlock(Block):
execution_context: ExecutionContext, execution_context: ExecutionContext,
**_kwargs, **_kwargs,
) -> BlockOutput: ) -> BlockOutput:
if not execution_context.human_in_the_loop_safe_mode:
if not execution_context.safe_mode:
logger.info( logger.info(
f"HITL block skipping review for node {node_exec_id} - safe mode disabled" f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
) )
@@ -116,12 +115,12 @@ class HumanInTheLoopBlock(Block):
decision = await self.handle_review_decision( decision = await self.handle_review_decision(
input_data=input_data.data, input_data=input_data.data,
user_id=user_id, user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id, node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id, graph_exec_id=graph_exec_id,
graph_id=graph_id, graph_id=graph_id,
graph_version=graph_version, graph_version=graph_version,
block_name=input_data.name, # Use user-provided name instead of block type
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable, editable=input_data.editable,
) )

View File

@@ -79,10 +79,6 @@ class ModelMetadata(NamedTuple):
provider: str provider: str
context_window: int context_window: int
max_output_tokens: int | None max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
class LlmModelMeta(EnumMeta): class LlmModelMeta(EnumMeta):
@@ -175,26 +171,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
V0_1_5_LG = "v0-1.5-lg" V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md" V0_1_0_MD = "v0-1.0-md"
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
json_schema = handler(schema)
llm_model_metadata = {}
for model in cls:
model_name = model.value
metadata = model.metadata
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
"title": metadata.display_name,
"provider": metadata.provider,
"provider_name": metadata.provider_name,
"name": model_name,
"price_tier": metadata.price_tier,
}
json_schema["llm_model"] = True
json_schema["llm_model_metadata"] = llm_model_metadata
return json_schema
@property @property
def metadata(self) -> ModelMetadata: def metadata(self) -> ModelMetadata:
return MODEL_METADATA[self] return MODEL_METADATA[self]
@@ -214,291 +190,119 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
MODEL_METADATA = { MODEL_METADATA = {
# https://platform.openai.com/docs/models # https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2), LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata( LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1 LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
), # o3-mini-2025-01-31 LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
LlmModel.O1: ModelMetadata(
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata(
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
), # o1-mini-2024-09-12
# GPT-5 models # GPT-5 models
LlmModel.GPT5_2: ModelMetadata( LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3 LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
), LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata( LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2 LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
), LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
LlmModel.GPT5: ModelMetadata( LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1 LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
),
LlmModel.GPT5_MINI: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_NANO: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_CHAT: ModelMetadata(
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
),
LlmModel.GPT41: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
),
LlmModel.GPT41_MINI: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT4O_MINI: ModelMetadata( LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1 "openai", 128000, 16384
), # gpt-4o-mini-2024-07-18 ), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata( LlmModel.GPT4O: ModelMetadata("openai", 128000, 16384), # gpt-4o-2024-08-06
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata( LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3 "openai", 128000, 4096
), # gpt-4-turbo-2024-04-09 ), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata( LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models # https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata( LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3 "anthropic", 200000, 32000
), # claude-opus-4-1-20250805 ), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata( LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3 "anthropic", 200000, 32000
), # claude-4-opus-20250514 ), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata( LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2 "anthropic", 200000, 64000
), # claude-4-sonnet-20250514 ), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata( LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3 "anthropic", 200000, 64000
), # claude-opus-4-5-20251101 ), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata( LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3 "anthropic", 200000, 64000
), # claude-sonnet-4-5-20250929 ), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata( LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2 "anthropic", 200000, 64000
), # claude-haiku-4-5-20251001 ), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata( LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2 "anthropic", 200000, 64000
), # claude-3-7-sonnet-20250219 ), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata( LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1 "anthropic", 200000, 4096
), # claude-3-haiku-20240307 ), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models # https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata( LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata("aiml_api", 32000, 8000),
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1 LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata("aiml_api", 128000, 40000),
), LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata("aiml_api", 128000, None),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata( LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata("aiml_api", 131000, 2000),
"aiml_api", LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata("aiml_api", 128000, None),
128000,
40000,
"Llama 3.1 Nemotron 70B Instruct",
"AI/ML",
"Nvidia",
1,
),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
),
# https://console.groq.com/docs/models # https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata( LlmModel.LLAMA3_3_70B: ModelMetadata("groq", 128000, 32768),
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1 LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
),
LlmModel.LLAMA3_1_8B: ModelMetadata(
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
),
# https://ollama.com/library # https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata( LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1 LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192, None),
), LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, None),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata( LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, None),
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1 LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models # https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata( LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
"open_router", LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata("open_router", 1048576, 65535),
1050000, LlmModel.GEMINI_2_5_FLASH: ModelMetadata("open_router", 1048576, 65535),
8192, LlmModel.GEMINI_2_0_FLASH: ModelMetadata("open_router", 1048576, 8192),
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata( LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router", "open_router", 1048576, 65535
1048576,
65535,
"Gemini 2.5 Flash Lite Preview 06.17",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
"open_router",
1048576,
8192,
"Gemini 2.0 Flash Lite 001",
"OpenRouter",
"Google",
1,
),
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
), ),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata("open_router", 1048576, 8192),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.DEEPSEEK_CHAT: ModelMetadata("open_router", 64000, 2048),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata("open_router", 163840, 163840),
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 8000),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata( LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router", "open_router",
128000, 128000,
16000, 16000,
"Sonar Deep Research",
"OpenRouter",
"Perplexity",
3,
), ),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata( LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router", "open_router", 131000, 4096
131000,
4096,
"Hermes 3 Llama 3.1 405B",
"OpenRouter",
"Nous Research",
1,
), ),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata( LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router", "open_router", 12288, 12288
12288,
12288,
"Hermes 3 Llama 3.1 70B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Qwen 3 235B A22B Thinking 2507",
"OpenRouter",
"Qwen",
1,
),
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
), ),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
LlmModel.GROK_4: ModelMetadata("open_router", 256000, 256000),
LlmModel.GROK_4_FAST: ModelMetadata("open_router", 2000000, 30000),
LlmModel.GROK_4_1_FAST: ModelMetadata("open_router", 2000000, 30000),
LlmModel.GROK_CODE_FAST_1: ModelMetadata("open_router", 256000, 10000),
LlmModel.KIMI_K2: ModelMetadata("open_router", 131000, 131000),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata("open_router", 262144, 262144),
LlmModel.QWEN3_CODER: ModelMetadata("open_router", 262144, 262144),
# Llama API models # Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata( LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata("llama_api", 128000, 4028),
"llama_api", LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
128000, LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
4028, LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
"Llama 4 Scout 17B 16E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Maverick 17B 128E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
),
# v0 by Vercel models # v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1), LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1), LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1), LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
} }
DEFAULT_LLM_MODEL = LlmModel.GPT5_2 DEFAULT_LLM_MODEL = LlmModel.GPT5_2
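
The hunk above strips ModelMetadata from seven fields down to the three the executor actually consumes; display_name, provider_name, creator_name, and price_tier disappear along with the __get_pydantic_json_schema__ hook that exposed them. A sketch of the resulting shape, using only values visible in this diff:

from typing import NamedTuple

class ModelMetadata(NamedTuple):
    provider: str
    context_window: int
    max_output_tokens: int | None

# Example entry from the trimmed MODEL_METADATA table above.
O3_METADATA = ModelMetadata("openai", 200000, 100000)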

View File

@@ -242,7 +242,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -343,7 +343,7 @@ async def test_smart_decision_maker_parameter_validation():
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -409,7 +409,7 @@ async def test_smart_decision_maker_parameter_validation():
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -471,7 +471,7 @@ async def test_smart_decision_maker_parameter_validation():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -535,7 +535,7 @@ async def test_smart_decision_maker_parameter_validation():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -658,7 +658,7 @@ async def test_smart_decision_maker_raw_response_conversion():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -730,7 +730,7 @@ async def test_smart_decision_maker_raw_response_conversion():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -786,7 +786,7 @@ async def test_smart_decision_maker_raw_response_conversion():
outputs = {} outputs = {}
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests
@@ -905,7 +905,7 @@ async def test_smart_decision_maker_agent_mode():
# Create a mock execution context # Create a mock execution context
mock_execution_context = ExecutionContext( mock_execution_context = ExecutionContext(
human_in_the_loop_safe_mode=False,
safe_mode=False,
) )
# Create a mock execution processor for agent mode tests # Create a mock execution processor for agent mode tests
@@ -1027,7 +1027,7 @@ async def test_smart_decision_maker_traditional_mode_default():
# Create execution context # Create execution context
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests # Create a mock execution processor for tests

View File

@@ -386,7 +386,7 @@ async def test_output_yielding_with_dynamic_fields():
outputs = {} outputs = {}
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock() mock_execution_processor = MagicMock()
async for output_name, output_value in block.run( async for output_name, output_value in block.run(
@@ -609,9 +609,7 @@ async def test_validation_errors_dont_pollute_conversation():
outputs = {} outputs = {}
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(
human_in_the_loop_safe_mode=False
)
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a proper mock execution processor for agent mode # Create a proper mock execution processor for agent mode
from collections import defaultdict from collections import defaultdict

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
import pytest_asyncio
import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from backend.util.logging import configure_logging from backend.util.logging import configure_logging
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
prisma_logger.setLevel(logging.INFO) prisma_logger.setLevel(logging.INFO)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
@pytest.fixture(scope="session")
async def server(): async def server():
from backend.util.test import SpinTestServer from backend.util.test import SpinTestServer
@@ -27,7 +27,7 @@ async def server():
yield server yield server
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
@pytest.fixture(scope="session", autouse=True)
async def graph_cleanup(server): async def graph_cleanup(server):
created_graph_ids = [] created_graph_ids = []
original_create_graph = server.agent_server.test_create_graph original_create_graph = server.agent_server.test_create_graph

View File

@@ -441,7 +441,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
static_output: bool = False, static_output: bool = False,
block_type: BlockType = BlockType.STANDARD, block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None, webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
): ):
""" """
Initialize the block with the given schema. Initialize the block with the given schema.
@@ -474,8 +473,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output self.static_output = static_output
self.block_type = block_type self.block_type = block_type
self.webhook_config = webhook_config self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: NodeExecutionStats = NodeExecutionStats() self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config: if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig): if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -623,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
input_data: BlockInput, input_data: BlockInput,
*, *,
user_id: str, user_id: str,
node_id: str,
node_exec_id: str, node_exec_id: str,
graph_exec_id: str, graph_exec_id: str,
graph_id: str, graph_id: str,
@@ -639,9 +637,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
- should_pause: True if execution should be paused for review - should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer) - input_data_to_use: The input data to use (may be modified by reviewer)
""" """
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper from backend.blocks.helpers.review import HITLReviewHelper
@@ -650,11 +647,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
decision = await HITLReviewHelper.handle_review_decision( decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data, input_data=input_data,
user_id=user_id, user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id, node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id, graph_exec_id=graph_exec_id,
graph_id=graph_id, graph_id=graph_id,
graph_version=graph_version, graph_version=graph_version,
execution_context=execution_context,
block_name=self.name, block_name=self.name,
editable=True, editable=True,
) )
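
With the is_sensitive_action constructor flag removed, a block opts into review through the requires_human_review attribute, and is_block_exec_need_review only runs while execution_context.safe_mode is on. A hedged sketch of the opt-in, assuming Block's import path is backend.data.block; the subclass itself is illustrative, not from this diff:

from backend.data.block import Block  # assumed import path for the Block base class

class ExampleSensitiveBlock(Block):
    """Illustrative block that should pause for human review."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replaces the old is_sensitive_action flag: a review is requested
        # only when this is True and execution_context.safe_mode is enabled.
        self.requires_human_review = True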

View File

@@ -99,15 +99,10 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.OPENAI_GPT_OSS_20B: 1, LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4, LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5, LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1, LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1, LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3, LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2, LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1, LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5, LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10, LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
@@ -131,6 +126,11 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.KIMI_K2: 1, LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1, LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9, LlmModel.QWEN3_CODER: 9,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.DEEPSEEK_R1_0528: 1,
# v0 by Vercel models # v0 by Vercel models
LlmModel.V0_1_5_MD: 1, LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2, LlmModel.V0_1_5_LG: 2,

View File

@@ -121,14 +121,10 @@ async def _raw_with_schema(
Supports placeholders: Supports placeholders:
- {schema_prefix}: Table/type prefix (e.g., "platform".) - {schema_prefix}: Table/type prefix (e.g., "platform".)
- {schema}: Raw schema name for application tables (e.g., platform) - {schema}: Raw schema name for application tables (e.g., platform)
- {pgvector_schema}: Schema where pgvector is installed (defaults to "public")
Note on pgvector types:
Use unqualified ::vector and <=> operator in queries. PostgreSQL resolves
these via search_path, which includes the schema where pgvector is installed
on all environments (local, CI, dev).
Args: Args:
query_template: SQL query with {schema_prefix} and/or {schema} placeholders
query_template: SQL query with {schema_prefix}, {schema}, and/or {pgvector_schema} placeholders
*args: Query parameters *args: Query parameters
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE. execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
client: Optional Prisma client for transactions (only used when execute=True). client: Optional Prisma client for transactions (only used when execute=True).
@@ -139,16 +135,20 @@ async def _raw_with_schema(
Example with vector type: Example with vector type:
await execute_raw_with_schema( await execute_raw_with_schema(
'INSERT INTO {schema_prefix}"Embedding" (vec) VALUES ($1::vector)',
'INSERT INTO {schema_prefix}"Embedding" (vec) VALUES ($1::{pgvector_schema}.vector)',
embedding_data embedding_data
) )
""" """
schema = get_database_schema() schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else "" schema_prefix = f'"{schema}".' if schema != "public" else ""
# pgvector extension is typically installed in "public" schema
# On Supabase it may be in "extensions" but "public" is the common default
pgvector_schema = "public"
formatted_query = query_template.format( formatted_query = query_template.format(
schema_prefix=schema_prefix, schema_prefix=schema_prefix,
schema=schema, schema=schema,
pgvector_schema=pgvector_schema,
) )
import prisma as prisma_module import prisma as prisma_module
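
The new {pgvector_schema} placeholder (defaulting to "public") lets queries qualify the vector type explicitly instead of relying on search_path. A hedged sketch of a caller, assuming the execute_raw_with_schema wrapper named in the docstring example above; the table, column, and limit are illustrative:

async def find_similar_embeddings(query_vector: list[float]) -> list[dict]:
    # {schema_prefix} resolves to the application schema,
    # {pgvector_schema} to the schema holding the pgvector extension.
    return await execute_raw_with_schema(
        'SELECT id FROM {schema_prefix}"Embedding" '
        "ORDER BY vec <=> $1::{pgvector_schema}.vector LIMIT 10",
        query_vector,
    )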

View File

@@ -103,18 +103,8 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
return redis.get_redis() return redis.get_redis()
def publish_event(self, event: M, channel_key: str): def publish_event(self, event: M, channel_key: str):
""" message, full_channel_name = self._serialize_message(event, channel_key)
Publish an event to Redis. Gracefully handles connection failures self.connection.publish(full_channel_name, message)
by logging the error instead of raising exceptions.
"""
try:
message, full_channel_name = self._serialize_message(event, channel_key)
self.connection.publish(full_channel_name, message)
except Exception:
logger.exception(
f"Failed to publish event to Redis channel {channel_key}. "
"Event bus operation will continue without Redis connectivity."
)
def listen_events(self, channel_key: str) -> Generator[M, None, None]: def listen_events(self, channel_key: str) -> Generator[M, None, None]:
pubsub, full_channel_name = self._get_pubsub_channel( pubsub, full_channel_name = self._get_pubsub_channel(
@@ -138,19 +128,9 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
return await redis.get_redis_async() return await redis.get_redis_async()
async def publish_event(self, event: M, channel_key: str): async def publish_event(self, event: M, channel_key: str):
""" message, full_channel_name = self._serialize_message(event, channel_key)
Publish an event to Redis. Gracefully handles connection failures connection = await self.connection
by logging the error instead of raising exceptions. await connection.publish(full_channel_name, message)
"""
try:
message, full_channel_name = self._serialize_message(event, channel_key)
connection = await self.connection
await connection.publish(full_channel_name, message)
except Exception:
logger.exception(
f"Failed to publish event to Redis channel {channel_key}. "
"Event bus operation will continue without Redis connectivity."
)
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]: async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
pubsub, full_channel_name = self._get_pubsub_channel( pubsub, full_channel_name = self._get_pubsub_channel(
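
With the try/except removed, publish_event now raises on Redis connection failures instead of logging and continuing, so callers that still want best-effort delivery have to handle it themselves. A hedged sketch of such a caller; the bus and event objects are illustrative:

import logging

logger = logging.getLogger(__name__)

async def publish_best_effort(bus, event, channel_key: str) -> None:
    try:
        await bus.publish_event(event, channel_key)
    except Exception:
        # Mirrors the handling this diff removes: log and keep going.
        logger.exception("Failed to publish event to Redis channel %s", channel_key)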

View File

@@ -1,56 +0,0 @@
"""
Tests for event_bus graceful degradation when Redis is unavailable.
"""
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import BaseModel
from backend.data.event_bus import AsyncRedisEventBus
class TestEvent(BaseModel):
"""Test event model."""
message: str
class TestNotificationBus(AsyncRedisEventBus[TestEvent]):
"""Test implementation of AsyncRedisEventBus."""
Model = TestEvent
@property
def event_bus_name(self) -> str:
return "test_event_bus"
@pytest.mark.asyncio
async def test_publish_event_handles_connection_failure_gracefully():
"""Test that publish_event logs exception instead of raising when Redis is unavailable."""
bus = TestNotificationBus()
event = TestEvent(message="test message")
# Mock get_redis_async to raise connection error
with patch(
"backend.data.event_bus.redis.get_redis_async",
side_effect=ConnectionError("Authentication required."),
):
# Should not raise exception
await bus.publish_event(event, "test_channel")
@pytest.mark.asyncio
async def test_publish_event_works_with_redis_available():
"""Test that publish_event works normally when Redis is available."""
bus = TestNotificationBus()
event = TestEvent(message="test message")
# Mock successful Redis connection
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
with patch("backend.data.event_bus.redis.get_redis_async", return_value=mock_redis):
await bus.publish_event(event, "test_channel")
mock_redis.publish.assert_called_once()

View File

@@ -81,10 +81,7 @@ class ExecutionContext(BaseModel):
This includes information needed by blocks, sub-graphs, and execution management. This includes information needed by blocks, sub-graphs, and execution management.
""" """
model_config = {"extra": "ignore"}
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
safe_mode: bool = True
user_timezone: str = "UTC" user_timezone: str = "UTC"
root_execution_id: Optional[str] = None root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None parent_execution_id: Optional[str] = None
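
This hunk folds human_in_the_loop_safe_mode and sensitive_action_safe_mode into a single safe_mode switch on ExecutionContext. A minimal sketch of constructing it under the new shape; the field values are illustrative:

from backend.data.execution import ExecutionContext

# Safe mode defaults to True; disable it explicitly for runs that skip review.
ctx = ExecutionContext(safe_mode=True, user_timezone="UTC")
assert ctx.safe_mode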

View File

@@ -3,7 +3,7 @@ import logging
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from prisma.enums import SubmissionStatus from prisma.enums import SubmissionStatus
from prisma.models import ( from prisma.models import (
@@ -20,7 +20,7 @@ from prisma.types import (
AgentNodeLinkCreateInput, AgentNodeLinkCreateInput,
StoreListingVersionWhereInput, StoreListingVersionWhereInput,
) )
from pydantic import BaseModel, BeforeValidator, Field, create_model
from pydantic import BaseModel, Field, create_model
from pydantic.fields import computed_field from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock from backend.blocks.agent import AgentExecutorBlock
@@ -62,31 +62,7 @@ logger = logging.getLogger(__name__)
class GraphSettings(BaseModel): class GraphSettings(BaseModel):
human_in_the_loop_safe_mode: bool | None = None
# Use Annotated with BeforeValidator to coerce None to default values.
# This handles cases where the database has null values for these fields.
model_config = {"extra": "ignore"}
human_in_the_loop_safe_mode: Annotated[
bool, BeforeValidator(lambda v: v if v is not None else True)
] = True
sensitive_action_safe_mode: Annotated[
bool, BeforeValidator(lambda v: v if v is not None else False)
] = False
@classmethod
def from_graph(
cls,
graph: "GraphModel",
hitl_safe_mode: bool | None = None,
sensitive_action_safe_mode: bool = False,
) -> "GraphSettings":
# Default to True if not explicitly set
if hitl_safe_mode is None:
hitl_safe_mode = True
return cls(
human_in_the_loop_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
)
class Link(BaseDbModel): class Link(BaseDbModel):
@@ -268,14 +244,10 @@ class BaseGraph(BaseDbModel):
return any( return any(
node.block_id node.block_id
for node in self.nodes for node in self.nodes
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
)
@computed_field
@property
def has_sensitive_action(self) -> bool:
return any(
node.block_id for node in self.nodes if node.block.is_sensitive_action
)
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
)
@property @property

View File

@@ -6,10 +6,10 @@ Handles all database operations for pending human reviews.
import asyncio import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional
from typing import Optional
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from prisma.models import AgentNodeExecution, PendingHumanReview
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput from prisma.types import PendingHumanReviewUpdateInput
from pydantic import BaseModel from pydantic import BaseModel
@@ -17,12 +17,8 @@ from backend.api.features.executions.review.model import (
PendingHumanReviewModel, PendingHumanReviewModel,
SafeJsonData, SafeJsonData,
) )
from backend.data.execution import get_graph_execution_meta
from backend.util.json import SafeJson from backend.util.json import SafeJson
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,125 +32,6 @@ class ReviewResult(BaseModel):
node_exec_id: str node_exec_id: str
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
"""Generate the special nodeExecId key for auto-approval records."""
return f"auto_approve_{graph_exec_id}_{node_id}"
async def check_approval(
node_exec_id: str,
graph_exec_id: str,
node_id: str,
user_id: str,
input_data: SafeJsonData | None = None,
) -> Optional[ReviewResult]:
"""
Check if there's an existing approval for this node execution.
Checks both:
1. Normal approval by node_exec_id (previous run of the same node execution)
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
Args:
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
node_id: ID of the node definition (not execution)
user_id: ID of the user (for data isolation)
input_data: Current input data (used for auto-approvals to avoid stale data)
Returns:
ReviewResult if approval found (either normal or auto), None otherwise
"""
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
# Check for either normal approval or auto-approval in a single query
existing_review = await PendingHumanReview.prisma().find_first(
where={
"OR": [
{"nodeExecId": node_exec_id},
{"nodeExecId": auto_approve_key},
],
"status": ReviewStatus.APPROVED,
"userId": user_id,
},
)
if existing_review:
is_auto_approval = existing_review.nodeExecId == auto_approve_key
logger.info(
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
)
# For auto-approvals, use current input_data to avoid replaying stale payload
# For normal approvals, use the stored payload (which may have been edited)
return ReviewResult(
data=(
input_data
if is_auto_approval and input_data is not None
else existing_review.payload
),
status=ReviewStatus.APPROVED,
message=(
"Auto-approved (user approved all future actions for this node)"
if is_auto_approval
else existing_review.reviewMessage or ""
),
processed=True,
node_exec_id=existing_review.nodeExecId,
)
return None
async def create_auto_approval_record(
user_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
node_id: str,
payload: SafeJsonData,
) -> None:
"""
Create an auto-approval record for a node in this execution.
This is stored as a PendingHumanReview with a special nodeExecId pattern
and status=APPROVED, so future executions of the same node can skip review.
Raises:
ValueError: If the graph execution doesn't belong to the user
"""
# Validate that the graph execution belongs to this user (defense in depth)
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
)
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
await PendingHumanReview.prisma().upsert(
where={"nodeExecId": auto_approve_key},
data={
"create": {
"nodeExecId": auto_approve_key,
"userId": user_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(payload),
"instructions": "Auto-approval record",
"editable": False,
"status": ReviewStatus.APPROVED,
"processed": True,
"reviewedAt": datetime.now(timezone.utc),
},
"update": {}, # Already exists, no update needed
},
)
async def get_or_create_human_review( async def get_or_create_human_review(
user_id: str, user_id: str,
node_exec_id: str, node_exec_id: str,
@@ -231,89 +108,6 @@ async def get_or_create_human_review(
) )
async def get_pending_review_by_node_exec_id(
node_exec_id: str, user_id: str
) -> Optional["PendingHumanReviewModel"]:
"""
Get a pending review by its node execution ID.
Args:
node_exec_id: The node execution ID to look up
user_id: User ID for authorization (only returns if review belongs to this user)
Returns:
The pending review if found and belongs to user, None otherwise
"""
review = await PendingHumanReview.prisma().find_first(
where={
"nodeExecId": node_exec_id,
"userId": user_id,
"status": ReviewStatus.WAITING,
}
)
if not review:
return None
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
return PendingHumanReviewModel.from_db(review, node_id=node_id)
async def get_reviews_by_node_exec_ids(
node_exec_ids: list[str], user_id: str
) -> dict[str, "PendingHumanReviewModel"]:
"""
Get multiple reviews by their node execution IDs regardless of status.
Unlike get_pending_reviews_by_node_exec_ids, this returns reviews in any status
(WAITING, APPROVED, REJECTED). Used for validation in idempotent operations.
Args:
node_exec_ids: List of node execution IDs to look up
user_id: User ID for authorization (only returns reviews belonging to this user)
Returns:
Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews
"""
if not node_exec_ids:
return {}
reviews = await PendingHumanReview.prisma().find_many(
where={
"nodeExecId": {"in": node_exec_ids},
"userId": user_id,
}
)
if not reviews:
return {}
# Batch fetch all node executions to avoid N+1 queries
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
node_execs = await AgentNodeExecution.prisma().find_many(
where={"id": {"in": node_exec_ids_to_fetch}},
include={"Node": True},
)
# Create mapping from node_exec_id to node_id
node_exec_id_to_node_id = {
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
}
result = {}
for review in reviews:
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
return result
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool: async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
""" """
Check if a graph execution has any pending reviews. Check if a graph execution has any pending reviews.
@@ -343,11 +137,8 @@ async def get_pending_reviews_for_user(
page_size: Number of reviews per page page_size: Number of reviews per page
Returns: Returns:
List of pending review models with node_id included List of pending review models
""" """
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
# Calculate offset for pagination # Calculate offset for pagination
offset = (page - 1) * page_size offset = (page - 1) * page_size
@@ -358,14 +149,7 @@ async def get_pending_reviews_for_user(
take=page_size, take=page_size,
) )
# Fetch node_id for each review from NodeExecution
return [PendingHumanReviewModel.from_db(review) for review in reviews]
result = []
for review in reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
async def get_pending_reviews_for_execution( async def get_pending_reviews_for_execution(
@@ -379,11 +163,8 @@ async def get_pending_reviews_for_execution(
user_id: User ID for security validation user_id: User ID for security validation
Returns: Returns:
List of pending review models with node_id included List of pending review models
""" """
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
reviews = await PendingHumanReview.prisma().find_many( reviews = await PendingHumanReview.prisma().find_many(
where={ where={
"userId": user_id, "userId": user_id,
@@ -393,14 +174,7 @@ async def get_pending_reviews_for_execution(
order={"createdAt": "asc"}, order={"createdAt": "asc"},
) )
# Fetch node_id for each review from NodeExecution
return [PendingHumanReviewModel.from_db(review) for review in reviews]
result = []
for review in reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
async def process_all_reviews_for_execution( async def process_all_reviews_for_execution(
@@ -409,68 +183,38 @@ async def process_all_reviews_for_execution(
) -> dict[str, PendingHumanReviewModel]: ) -> dict[str, PendingHumanReviewModel]:
"""Process all pending reviews for an execution with approve/reject decisions. """Process all pending reviews for an execution with approve/reject decisions.
Handles race conditions gracefully: if a review was already processed with the
same decision by a concurrent request, it's treated as success rather than error.
Args: Args:
user_id: User ID for ownership validation user_id: User ID for ownership validation
review_decisions: Map of node_exec_id -> (status, reviewed_data, message) review_decisions: Map of node_exec_id -> (status, reviewed_data, message)
Returns: Returns:
Dict of node_exec_id -> updated review model (includes already-processed reviews) Dict of node_exec_id -> updated review model
""" """
if not review_decisions: if not review_decisions:
return {} return {}
node_exec_ids = list(review_decisions.keys()) node_exec_ids = list(review_decisions.keys())
# Get all reviews (both WAITING and already processed) for the user # Get all reviews for validation
all_reviews = await PendingHumanReview.prisma().find_many( reviews = await PendingHumanReview.prisma().find_many(
where={ where={
"nodeExecId": {"in": node_exec_ids}, "nodeExecId": {"in": node_exec_ids},
"userId": user_id, "userId": user_id,
"status": ReviewStatus.WAITING,
}, },
) )
# Separate into pending and already-processed reviews # Validate all reviews can be processed
reviews_to_process = [] if len(reviews) != len(node_exec_ids):
already_processed = [] missing_ids = set(node_exec_ids) - {review.nodeExecId for review in reviews}
for review in all_reviews:
if review.status == ReviewStatus.WAITING:
reviews_to_process.append(review)
else:
already_processed.append(review)
# Check for truly missing reviews (not found at all)
found_ids = {review.nodeExecId for review in all_reviews}
missing_ids = set(node_exec_ids) - found_ids
if missing_ids:
raise ValueError( raise ValueError(
f"Reviews not found or access denied: {', '.join(missing_ids)}" f"Reviews not found, access denied, or not in WAITING status: {', '.join(missing_ids)}"
) )
# Validate already-processed reviews have compatible status (same decision) # Create parallel update tasks
# This handles race conditions where another request processed the same reviews
for review in already_processed:
requested_status = review_decisions[review.nodeExecId][0]
if review.status != requested_status:
raise ValueError(
f"Review {review.nodeExecId} was already processed with status "
f"{review.status}, cannot change to {requested_status}"
)
# Log if we're handling a race condition (some reviews already processed)
if already_processed:
already_processed_ids = [r.nodeExecId for r in already_processed]
logger.info(
f"Race condition handled: {len(already_processed)} review(s) already "
f"processed by concurrent request: {already_processed_ids}"
)
# Create parallel update tasks for reviews that still need processing
update_tasks = [] update_tasks = []
for review in reviews_to_process: for review in reviews:
new_status, reviewed_data, message = review_decisions[review.nodeExecId] new_status, reviewed_data, message = review_decisions[review.nodeExecId]
has_data_changes = reviewed_data is not None and reviewed_data != review.payload has_data_changes = reviewed_data is not None and reviewed_data != review.payload
@@ -495,27 +239,16 @@ async def process_all_reviews_for_execution(
update_tasks.append(task) update_tasks.append(task)
# Execute all updates in parallel and get updated reviews # Execute all updates in parallel and get updated reviews
updated_reviews = await asyncio.gather(*update_tasks) if update_tasks else [] updated_reviews = await asyncio.gather(*update_tasks)
# Note: Execution resumption is now handled at the API layer after ALL reviews # Note: Execution resumption is now handled at the API layer after ALL reviews
# for an execution are processed (both approved and rejected) # for an execution are processed (both approved and rejected)
# Fetch node_id for each review and return as dict for easy access # Return as dict for easy access
# Local import to avoid event loop conflicts in tests return {
from backend.data.execution import get_node_execution review.nodeExecId: PendingHumanReviewModel.from_db(review)
for review in updated_reviews
# Combine updated reviews with already-processed ones (for idempotent response) }
all_result_reviews = list(updated_reviews) + already_processed
result = {}
for review in all_result_reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
return result
async def update_review_processed_status(node_exec_id: str, processed: bool) -> None: async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
@@ -523,44 +256,3 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
await PendingHumanReview.prisma().update( await PendingHumanReview.prisma().update(
where={"nodeExecId": node_exec_id}, data={"processed": processed} where={"nodeExecId": node_exec_id}, data={"processed": processed}
) )
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
"""
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
Args:
graph_exec_id: The graph execution ID
user_id: User ID who owns the execution (for security validation)
Returns:
Number of reviews cancelled
Raises:
ValueError: If the graph execution doesn't belong to the user
"""
# Validate user ownership before cancelling reviews
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
)
result = await PendingHumanReview.prisma().update_many(
where={
"graphExecId": graph_exec_id,
"userId": user_id,
"status": ReviewStatus.WAITING,
},
data={
"status": ReviewStatus.REJECTED,
"reviewMessage": "Execution was stopped by user",
"processed": True,
"reviewedAt": datetime.now(timezone.utc),
},
)
return result

View File

@@ -36,7 +36,7 @@ def sample_db_review():
return mock_review return mock_review
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_get_or_create_human_review_new( async def test_get_or_create_human_review_new(
mocker: pytest_mock.MockFixture, mocker: pytest_mock.MockFixture,
sample_db_review, sample_db_review,
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
sample_db_review.status = ReviewStatus.WAITING sample_db_review.status = ReviewStatus.WAITING
sample_db_review.processed = False sample_db_review.processed = False
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
result = await get_or_create_human_review( result = await get_or_create_human_review(
user_id="test-user-123", user_id="test-user-123",
@@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new(
assert result is None assert result is None
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio
async def test_get_or_create_human_review_approved( async def test_get_or_create_human_review_approved(
mocker: pytest_mock.MockFixture, mocker: pytest_mock.MockFixture,
sample_db_review, sample_db_review,
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
sample_db_review.processed = False sample_db_review.processed = False
sample_db_review.reviewMessage = "Looks good" sample_db_review.reviewMessage = "Looks good"
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review) mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
result = await get_or_create_human_review( result = await get_or_create_human_review(
user_id="test-user-123", user_id="test-user-123",
@@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved(
assert result.message == "Looks good" assert result.message == "Looks good"
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio
async def test_has_pending_reviews_for_graph_exec_true( async def test_has_pending_reviews_for_graph_exec_true(
mocker: pytest_mock.MockFixture, mocker: pytest_mock.MockFixture,
): ):
@@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true(
assert result is True assert result is True
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio
async def test_has_pending_reviews_for_graph_exec_false( async def test_has_pending_reviews_for_graph_exec_false(
mocker: pytest_mock.MockFixture, mocker: pytest_mock.MockFixture,
): ):
@@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false(
assert result is False assert result is False
@pytest.mark.asyncio(loop_scope="function") @pytest.mark.asyncio
async def test_get_pending_reviews_for_user( async def test_get_pending_reviews_for_user(
mocker: pytest_mock.MockFixture, mocker: pytest_mock.MockFixture,
sample_db_review, sample_db_review,
@@ -131,19 +131,10 @@ async def test_get_pending_reviews_for_user(
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review]) mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)
assert len(result) == 1
assert result[0].node_exec_id == "test_node_123"
assert result[0].node_id == "test_node_def_789"
# Verify pagination parameters
call_args = mock_find_many.return_value.find_many.call_args
@@ -151,7 +142,7 @@ async def test_get_pending_reviews_for_user(
assert call_args.kwargs["take"] == 10
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_get_pending_reviews_for_execution(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -160,21 +151,12 @@ async def test_get_pending_reviews_for_execution(
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await get_pending_reviews_for_execution(
"test_graph_exec_456", "test-user-123"
)
assert len(result) == 1
assert result[0].graph_exec_id == "test_graph_exec_456"
assert result[0].node_id == "test_node_def_789"
# Verify it filters by execution and user
call_args = mock_find_many.return_value.find_many.call_args
@@ -184,7 +166,7 @@ async def test_get_pending_reviews_for_execution(
assert where_clause["status"] == ReviewStatus.WAITING
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_process_all_reviews_for_execution_success(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -219,14 +201,6 @@ async def test_process_all_reviews_for_execution_success(
new=AsyncMock(return_value=[updated_review]),
)
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await process_all_reviews_for_execution(
user_id="test-user-123",
review_decisions={
@@ -237,10 +211,9 @@ async def test_process_all_reviews_for_execution_success(
assert len(result) == 1
assert "test_node_123" in result
assert result["test_node_123"].status == ReviewStatus.APPROVED
assert result["test_node_123"].node_id == "test_node_def_789"
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_process_all_reviews_for_execution_validation_errors(
mocker: pytest_mock.MockFixture,
):
@@ -260,7 +233,7 @@ async def test_process_all_reviews_for_execution_validation_errors(
)
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_process_all_reviews_edit_permission_error(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -286,7 +259,7 @@ async def test_process_all_reviews_edit_permission_error(
)
@pytest.mark.asyncio(loop_scope="function")
@pytest.mark.asyncio
async def test_process_all_reviews_mixed_approval_rejection(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -356,14 +329,6 @@ async def test_process_all_reviews_mixed_approval_rejection(
new=AsyncMock(return_value=[approved_review, rejected_review]),
)
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await process_all_reviews_for_execution(
user_id="test-user-123",
review_decisions={
@@ -375,5 +340,3 @@ async def test_process_all_reviews_mixed_approval_rejection(
assert len(result) == 2
assert "test_node_123" in result
assert "test_node_456" in result
assert result["test_node_123"].node_id == "test_node_def_789"
assert result["test_node_456"].node_id == "test_node_def_789"

View File

@@ -211,6 +211,22 @@ class AgentRejectionData(BaseNotificationData):
return value
class WaitlistLaunchData(BaseNotificationData):
"""Notification data for when an agent from a waitlist is launched."""
agent_name: str
waitlist_name: str
store_url: str
launched_at: datetime
@field_validator("launched_at")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
NotificationData = Annotated[
Union[
AgentRunData,
@@ -223,6 +239,7 @@ NotificationData = Annotated[
DailySummaryData,
RefundRequestData,
BaseSummaryData,
WaitlistLaunchData,
],
Field(discriminator="type"),
]
@@ -273,6 +290,7 @@ def get_notif_data_type(
NotificationType.REFUND_PROCESSED: RefundRequestData,
NotificationType.AGENT_APPROVED: AgentApprovalData,
NotificationType.AGENT_REJECTED: AgentRejectionData,
NotificationType.WAITLIST_LAUNCH: WaitlistLaunchData,
}[notification_type]
@@ -318,6 +336,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
NotificationType.WAITLIST_LAUNCH: QueueType.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
@@ -337,6 +356,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: "refund_processed.html",
NotificationType.AGENT_APPROVED: "agent_approved.html",
NotificationType.AGENT_REJECTED: "agent_rejected.html",
NotificationType.WAITLIST_LAUNCH: "waitlist_launch.html",
}[self.notification_type]
@property
@@ -354,6 +374,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed", NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!", NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates", NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
NotificationType.WAITLIST_LAUNCH: "🚀 {{data.agent_name}} is now available!",
}[self.notification_type] }[self.notification_type]
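For reference, a minimal sketch of how the new WaitlistLaunchData payload is meant to be constructed, assuming it is importable from this notifications data module and that its discriminator/type field is defaulted like the other payload models; all values below are illustrative.

```python
from datetime import datetime, timezone

from backend.data.notifications import WaitlistLaunchData  # assumed import path

# A timezone-aware launched_at passes the validate_timezone field validator.
launch = WaitlistLaunchData(
    agent_name="Example Agent",
    waitlist_name="Example Waitlist",
    store_url="https://platform.example.com/store/example-agent",
    launched_at=datetime.now(timezone.utc),
)
print(launch.agent_name)

# A naive datetime is rejected; pydantic raises a ValidationError (a ValueError subclass).
try:
    WaitlistLaunchData(
        agent_name="Example Agent",
        waitlist_name="Example Waitlist",
        store_url="https://platform.example.com/store/example-agent",
        launched_at=datetime(2026, 1, 20, 12, 0),
    )
except ValueError as err:
    print(err)
```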

View File

@@ -50,8 +50,6 @@ from backend.data.graph import (
validate_graph_execution_permissions,
)
from backend.data.human_review import (
cancel_pending_reviews_for_execution,
check_approval,
get_or_create_human_review,
has_pending_reviews_for_graph_exec,
update_review_processed_status,
@@ -192,8 +190,6 @@ class DatabaseManager(AppService):
get_user_notification_preference = _(get_user_notification_preference)
# Human In The Loop
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
check_approval = _(check_approval)
get_or_create_human_review = _(get_or_create_human_review)
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
update_review_processed_status = _(update_review_processed_status)
@@ -317,8 +313,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
set_execution_kv_data = d.set_execution_kv_data
# Human In The Loop
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
check_approval = d.check_approval
get_or_create_human_review = d.get_or_create_human_review
update_review_processed_status = d.update_review_processed_status

View File

@@ -309,7 +309,7 @@ def ensure_embeddings_coverage():
# Process in batches until no more missing embeddings
while True:
result = db_client.backfill_missing_embeddings(batch_size=100)
result = db_client.backfill_missing_embeddings(batch_size=10)
total_processed += result["processed"]
total_success += result["success"]

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import human_review as human_review_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data.block import (
@@ -750,27 +749,9 @@ async def stop_graph_execution(
if graph_exec.status in [
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
]:
# If the graph is queued/incomplete/paused for review, terminate immediately
# No need to wait for executor since it's not actively running
# If graph is in REVIEW status, clean up pending reviews before terminating
if graph_exec.status == ExecutionStatus.REVIEW:
# Use human_review_db if Prisma connected, else database manager
review_db = (
human_review_db
if prisma.is_connected()
else get_database_manager_async_client()
)
# Mark all pending reviews as rejected/cancelled
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
graph_exec_id, user_id
)
logger.info(
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
)
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
graph_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
@@ -892,8 +873,11 @@ async def add_graph_execution(
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
safe_mode=(
settings.human_in_the_loop_safe_mode
if settings.human_in_the_loop_safe_mode is not None
else True
),
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
),
@@ -906,28 +890,9 @@ async def add_graph_execution(
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Queueing execution {graph_exec.id}")
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
# Update execution status to QUEUED BEFORE publishing to prevent race condition
# where two concurrent requests could both publish the same execution
updated_exec = await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.QUEUED,
)
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
# If another request already updated the status, this execution will not be QUEUED
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
logger.warning(
f"Skipping queue publish for execution {graph_exec.id} - "
f"status update failed or execution already queued by another request"
)
return graph_exec
graph_exec.status = ExecutionStatus.QUEUED
# Publish to execution queue for executor to pick up
# This happens AFTER status update to ensure only one request publishes
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -935,6 +900,13 @@ async def add_graph_execution(
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
# Update execution status to QUEUED
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
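The removed comments above describe a claim-then-publish pattern: mark the execution QUEUED first, and only the request whose status transition succeeds publishes to the queue. A self-contained sketch of the idea follows; the helper names and the in-memory status store are illustrative stand-ins, not the project's actual API.

```python
import asyncio

# Illustrative status store; in the real system this is a conditional database update.
_status: dict[str, str] = {"exec-1": "CREATED"}
_lock = asyncio.Lock()


async def try_mark_queued(execution_id: str) -> bool:
    # Stand-in for an atomic UPDATE ... SET status = QUEUED WHERE status = CREATED.
    async with _lock:
        if _status.get(execution_id) == "CREATED":
            _status[execution_id] = "QUEUED"
            return True
        return False


async def enqueue_once(execution_id: str) -> bool:
    if not await try_mark_queued(execution_id):
        return False  # another request already queued this execution; skip publishing
    print(f"publishing {execution_id}")  # stand-in for the RabbitMQ publish
    return True


async def main() -> None:
    results = await asyncio.gather(enqueue_once("exec-1"), enqueue_once("exec-1"))
    print(results)  # exactly one of the concurrent calls publishes


asyncio.run(main())
```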

View File

@@ -4,7 +4,6 @@ import pytest
from pytest_mock import MockerFixture
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.data.execution import ExecutionStatus
from backend.util.mock import MockObject
@@ -347,7 +346,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
# Mock the queue and event bus
@@ -388,7 +386,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_user.timezone = "UTC" mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock() mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True mock_settings.human_in_the_loop_safe_mode = True
mock_settings.sensitive_action_safe_mode = False
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user) mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings) mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
@@ -613,7 +610,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
@@ -655,7 +651,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_user.timezone = "UTC" mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock() mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True mock_settings.human_in_the_loop_safe_mode = True
mock_settings.sensitive_action_safe_mode = False
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user) mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings) mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
@@ -673,232 +668,3 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
@pytest.mark.asyncio
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
mocker: MockerFixture,
):
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
graph_exec_id = "test-exec-123"
# Mock graph execution in REVIEW status
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_graph_exec.id = graph_exec_id
mock_graph_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = True
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=2 # 2 reviews cancelled
)
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
mock_get_child_executions.return_value = [] # No children
# Call stop_graph_execution with timeout to allow status check
await stop_graph_execution(
user_id=user_id,
graph_exec_id=graph_exec_id,
wait_timeout=1.0, # Wait to allow status check
cascade=True,
)
# Verify pending reviews were cancelled
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
graph_exec_id, user_id
)
# Verify execution status was updated to TERMINATED
mock_execution_db.update_graph_execution_stats.assert_called_once()
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
assert call_kwargs["graph_exec_id"] == graph_exec_id
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
@pytest.mark.asyncio
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
mocker: MockerFixture,
):
"""Test that stop uses database manager when Prisma is not connected."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
graph_exec_id = "test-exec-456"
# Mock graph execution in REVIEW status
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_graph_exec.id = graph_exec_id
mock_graph_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
# Prisma is NOT connected
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = False
# Mock database manager client
mock_get_db_manager = mocker.patch(
"backend.executor.utils.get_database_manager_async_client"
)
mock_db_manager = mocker.AsyncMock()
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=3 # 3 reviews cancelled
)
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
mock_get_db_manager.return_value = mock_db_manager
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
mock_get_child_executions.return_value = [] # No children
# Call stop_graph_execution with timeout
await stop_graph_execution(
user_id=user_id,
graph_exec_id=graph_exec_id,
wait_timeout=1.0,
cascade=True,
)
# Verify database manager was used for cancel_pending_reviews
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
graph_exec_id, user_id
)
# Verify execution status was updated via database manager
mock_db_manager.update_graph_execution_stats.assert_called_once()
@pytest.mark.asyncio
async def test_stop_graph_execution_cascades_to_child_with_reviews(
mocker: MockerFixture,
):
"""Test that stopping parent execution cascades to children and cancels their reviews."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
parent_exec_id = "parent-exec"
child_exec_id = "child-exec"
# Mock parent execution in RUNNING status
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_parent_exec.id = parent_exec_id
mock_parent_exec.status = ExecutionStatus.RUNNING
# Mock child execution in REVIEW status
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_child_exec.id = child_exec_id
mock_child_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = True
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=1 # 1 child review cancelled
)
# Mock execution_db to return different status based on which execution is queried
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
# Track call count to simulate status transition
call_count = {"count": 0}
async def get_exec_meta_side_effect(execution_id, user_id):
call_count["count"] += 1
if execution_id == parent_exec_id:
# After a few calls (child processing happens), transition parent to TERMINATED
# This simulates the executor service processing the stop request
if call_count["count"] > 3:
mock_parent_exec.status = ExecutionStatus.TERMINATED
return mock_parent_exec
elif execution_id == child_exec_id:
return mock_child_exec
return None
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
side_effect=get_exec_meta_side_effect
)
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
# Mock _get_child_executions to return the child
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
def get_children_side_effect(parent_id):
if parent_id == parent_exec_id:
return [mock_child_exec]
return []
mock_get_child_executions.side_effect = get_children_side_effect
# Call stop_graph_execution on parent with cascade=True
await stop_graph_execution(
user_id=user_id,
graph_exec_id=parent_exec_id,
wait_timeout=1.0,
cascade=True,
)
# Verify child reviews were cancelled
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
child_exec_id, user_id
)
# Verify both parent and child status updates
assert mock_execution_db.update_graph_execution_stats.call_count >= 1

View File

@@ -0,0 +1,59 @@
{# Waitlist Launch Notification Email Template #}
{#
Template variables:
data.agent_name: the name of the launched agent
data.waitlist_name: the name of the waitlist the user joined
data.store_url: URL to view the agent in the store
data.launched_at: when the agent was launched
Subject: {{ data.agent_name }} is now available!
#}
{% block content %}
<h1 style="color: #7c3aed; font-size: 32px; font-weight: 700; margin: 0 0 24px 0; text-align: center;">
The wait is over!
</h1>
<p style="color: #586069; font-size: 18px; text-align: center; margin: 0 0 24px 0;">
<strong>'{{ data.agent_name }}'</strong> is now live in the AutoGPT Store!
</p>
<div style="height: 32px; background: transparent;"></div>
<div style="background: #f3e8ff; border: 1px solid #d8b4fe; border-radius: 8px; padding: 20px; margin: 0;">
<h3 style="color: #6b21a8; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
You're one of the first to know!
</h3>
<p style="color: #6b21a8; margin: 0; font-size: 16px; line-height: 1.5;">
You signed up for the <strong>{{ data.waitlist_name }}</strong> waitlist, and we're excited to let you know that this agent is now ready for you to use.
</p>
</div>
<div style="height: 32px; background: transparent;"></div>
<div style="text-align: center; margin: 24px 0;">
<a href="{{ data.store_url }}" style="display: inline-block; background: linear-gradient(135deg, #7c3aed 0%, #5b21b6 100%); color: white; text-decoration: none; padding: 14px 28px; border-radius: 6px; font-weight: 600; font-size: 16px;">
Get {{ data.agent_name }} Now
</a>
</div>
<div style="height: 32px; background: transparent;"></div>
<div style="background: #d1ecf1; border: 1px solid #bee5eb; border-radius: 8px; padding: 20px; margin: 0;">
<h3 style="color: #0c5460; font-size: 16px; font-weight: 600; margin: 0 0 12px 0;">
What can you do now?
</h3>
<ul style="color: #0c5460; margin: 0; padding-left: 18px; font-size: 16px; line-height: 1.6;">
<li>Visit the store to learn more about what this agent can do</li>
<li>Install and start using the agent right away</li>
<li>Share it with others who might find it useful</li>
</ul>
</div>
<div style="height: 32px; background: transparent;"></div>
<p style="color: #6a737d; font-size: 14px; text-align: center; margin: 24px 0;">
Thank you for helping us prioritize what to build! Your interest made this happen.
</p>
{% endblock %}
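A minimal render check for the template above, assuming it is saved as waitlist_launch.html in a Jinja2-loadable templates directory; the loader path and all values are illustrative.

```python
from datetime import datetime, timezone

from jinja2 import Environment, FileSystemLoader

# Assumed location of the template; adjust the loader path to the real templates directory.
env = Environment(loader=FileSystemLoader("templates"), autoescape=True)
template = env.get_template("waitlist_launch.html")
html = template.render(
    data={
        "agent_name": "Example Agent",
        "waitlist_name": "Example Waitlist",
        "store_url": "https://platform.example.com/store/example-agent",
        "launched_at": datetime.now(timezone.utc),
    }
)
print(html[:200])
```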

View File

@@ -350,19 +350,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Whether to mark failed scans as clean or not", description="Whether to mark failed scans as clean or not",
) )
agentgenerator_host: str = Field(
default="",
description="The host for the Agent Generator service (empty to use built-in)",
)
agentgenerator_port: int = Field(
default=8000,
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=120,
description="The timeout in seconds for Agent Generator service requests",
)
enable_example_blocks: bool = Field(
default=False,
description="Whether to enable example blocks in production",
@@ -679,12 +666,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
default="https://cloud.langfuse.com", description="Langfuse host URL" default="https://cloud.langfuse.com", description="Langfuse host URL"
) )
# PostHog analytics
posthog_api_key: str = Field(default="", description="PostHog API key")
posthog_host: str = Field(
default="https://eu.i.posthog.com", description="PostHog host URL"
)
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",

View File

@@ -1,4 +1,3 @@
import asyncio
import inspect
import logging
import time
@@ -59,11 +58,6 @@ class SpinTestServer:
self.db_api.__exit__(exc_type, exc_val, exc_tb)
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
# Give services time to fully shut down
# This prevents event loop issues where services haven't fully cleaned up
# before the next test starts
await asyncio.sleep(0.5)
def setup_dependency_overrides(self):
# Override get_user_id for testing
self.agent_server.set_test_dependency_overrides(

View File

@@ -0,0 +1,53 @@
-- CreateEnum
CREATE TYPE "WaitlistExternalStatus" AS ENUM ('DONE', 'NOT_STARTED', 'CANCELED', 'WORK_IN_PROGRESS');
-- AlterEnum
ALTER TYPE "NotificationType" ADD VALUE 'WAITLIST_LAUNCH';
-- CreateTable
CREATE TABLE "WaitlistEntry" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"storeListingId" TEXT,
"owningUserId" TEXT NOT NULL,
"slug" TEXT NOT NULL,
"search" tsvector DEFAULT ''::tsvector,
"name" TEXT NOT NULL,
"subHeading" TEXT NOT NULL,
"videoUrl" TEXT,
"agentOutputDemoUrl" TEXT,
"imageUrls" TEXT[],
"description" TEXT NOT NULL,
"categories" TEXT[],
"status" "WaitlistExternalStatus" NOT NULL DEFAULT 'NOT_STARTED',
"votes" INTEGER NOT NULL DEFAULT 0,
"unaffiliatedEmailUsers" TEXT[] DEFAULT ARRAY[]::TEXT[],
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "WaitlistEntry_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "_joinedWaitlists" (
"A" TEXT NOT NULL,
"B" TEXT NOT NULL
);
-- CreateIndex
CREATE UNIQUE INDEX "_joinedWaitlists_AB_unique" ON "_joinedWaitlists"("A", "B");
-- CreateIndex
CREATE INDEX "_joinedWaitlists_B_index" ON "_joinedWaitlists"("B");
-- AddForeignKey
ALTER TABLE "WaitlistEntry" ADD CONSTRAINT "WaitlistEntry_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "WaitlistEntry" ADD CONSTRAINT "WaitlistEntry_owningUserId_fkey" FOREIGN KEY ("owningUserId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "_joinedWaitlists" ADD CONSTRAINT "_joinedWaitlists_A_fkey" FOREIGN KEY ("A") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "_joinedWaitlists" ADD CONSTRAINT "_joinedWaitlists_B_fkey" FOREIGN KEY ("B") REFERENCES "WaitlistEntry"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -1,37 +1,11 @@
-- CreateExtension
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
DO $$
DECLARE
current_schema_name text;
vector_schema text;
BEGIN
-- Get the current schema from search_path
SELECT current_schema() INTO current_schema_name;
-- Check if vector extension exists and which schema it's in
SELECT n.nspname INTO vector_schema
FROM pg_extension e
JOIN pg_namespace n ON e.extnamespace = n.oid
WHERE e.extname = 'vector';
-- Handle removal if in wrong schema
IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
BEGIN
-- Vector exists in a different schema, drop it first
RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
vector_schema, current_schema_name;
EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
EXCEPTION WHEN OTHERS THEN
RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
vector_schema, SQLERRM;
END;
END IF;
-- Create extension in current schema (let it fail naturally if not available)
EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
END $$;
-- Create in public schema so vector type is available across all schemas
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'vector extension not available or already exists, skipping';
END $$;
-- CreateEnum
@@ -45,7 +19,7 @@ CREATE TABLE "UnifiedContentEmbedding" (
"contentType" "ContentType" NOT NULL, "contentType" "ContentType" NOT NULL,
"contentId" TEXT NOT NULL, "contentId" TEXT NOT NULL,
"userId" TEXT, "userId" TEXT,
"embedding" vector(1536) NOT NULL, "embedding" public.vector(1536) NOT NULL,
"searchableText" TEXT NOT NULL, "searchableText" TEXT NOT NULL,
"metadata" JSONB NOT NULL DEFAULT '{}', "metadata" JSONB NOT NULL DEFAULT '{}',
@@ -71,4 +45,4 @@ CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" O
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
-- Note: Drop first in case Prisma created a btree index (Prisma doesn't support HNSW)
DROP INDEX IF EXISTS "UnifiedContentEmbedding_embedding_idx";
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" vector_cosine_ops);
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);

View File

@@ -0,0 +1,71 @@
-- Acknowledge Supabase-managed extensions to prevent drift warnings
-- These extensions are pre-installed by Supabase in specific schemas
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
-- Create schemas (safe in both CI and Supabase)
CREATE SCHEMA IF NOT EXISTS "extensions";
-- Extensions that exist in both CI and Supabase
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgcrypto extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'uuid-ossp extension not available, skipping';
END $$;
-- Supabase-specific extensions (skip gracefully in CI)
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_net extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgjwt extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "graphql";
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pg_graphql extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "pgsodium";
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'pgsodium extension not available, skipping';
END $$;
DO $$
BEGIN
CREATE SCHEMA IF NOT EXISTS "vault";
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
EXCEPTION WHEN OTHERS THEN
RAISE NOTICE 'supabase_vault extension not available, skipping';
END $$;
-- Return to platform
CREATE SCHEMA IF NOT EXISTS "platform";

View File

@@ -1,7 +0,0 @@
-- Remove NodeExecution foreign key from PendingHumanReview
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
-- to AgentNodeExecution since PendingHumanReview records can persist after node
-- execution records are deleted.
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";

View File

@@ -4204,14 +4204,14 @@ strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
[[package]]
name = "posthog"
version = "7.6.0"
version = "6.1.1"
description = "Integrate PostHog into any python application."
optional = false
python-versions = ">=3.10"
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "posthog-7.6.0-py3-none-any.whl", hash = "sha256:c4dd78cf77c4fecceb965f86066e5ac37886ef867d68ffe75a1db5d681d7d9ad"},
{file = "posthog-6.1.1-py3-none-any.whl", hash = "sha256:329fd3d06b4d54cec925f47235bd8e327c91403c2f9ec38f1deb849535934dba"},
{file = "posthog-7.6.0.tar.gz", hash = "sha256:941dfd278ee427c9b14640f09b35b5bb52a71bdf028d7dbb7307e1838fd3002e"},
{file = "posthog-6.1.1.tar.gz", hash = "sha256:b453f54c4a2589da859fd575dd3bf86fcb40580727ec399535f268b1b9f318b8"},
]
[package.dependencies]
@@ -4225,7 +4225,7 @@ typing-extensions = ">=4.2.0"
[package.extras]
dev = ["django-stubs", "lxml", "mypy", "mypy-baseline", "packaging", "pre-commit", "pydantic", "ruff", "setuptools", "tomli", "tomli_w", "twine", "types-mock", "types-python-dateutil", "types-requests", "types-setuptools", "types-six", "wheel"]
langchain = ["langchain (>=0.2.0)"]
test = ["anthropic (>=0.72)", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=1.0)", "langchain-community (>=0.4)", "langchain-core (>=1.0)", "langchain-openai (>=1.0)", "langgraph (>=1.0)", "mock (>=2.0.0)", "openai (>=2.0)", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"]
test = ["anthropic", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=0.3.15)", "langchain-community (>=0.3.25)", "langchain-core (>=0.3.65)", "langchain-openai (>=0.3.22)", "langgraph (>=0.4.8)", "mock (>=2.0.0)", "openai", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"]
[[package]]
name = "postmarker"
@@ -7512,4 +7512,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf"
content-hash = "18b92e09596298c82432e4d0a85cb6d80a40b4229bee0a0c15f0529fd6cb21a4"

View File

@@ -85,7 +85,6 @@ exa-py = "^1.14.20"
croniter = "^6.0.0" croniter = "^6.0.0"
stagehand = "^0.5.1" stagehand = "^0.5.1"
gravitas-md2gdocs = "^0.1.0" gravitas-md2gdocs = "^0.1.0"
posthog = "^7.6.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1" aiohappyeyeballs = "^2.6.1"

View File

@@ -69,6 +69,10 @@ model User {
OAuthAuthorizationCodes OAuthAuthorizationCode[]
OAuthAccessTokens OAuthAccessToken[]
OAuthRefreshTokens OAuthRefreshToken[]
// Waitlist relations
waitlistEntries WaitlistEntry[]
joinedWaitlists WaitlistEntry[] @relation("joinedWaitlists")
}
enum OnboardingStep {
@@ -295,6 +299,7 @@ enum NotificationType {
REFUND_PROCESSED
AGENT_APPROVED
AGENT_REJECTED
WAITLIST_LAUNCH
}
model NotificationEvent {
@@ -517,6 +522,8 @@ model AgentNodeExecution {
stats Json?
PendingHumanReview PendingHumanReview?
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
@@index([agentNodeId, executionStatus])
@@index([addedTime, queuedTime])
@@ -565,7 +572,6 @@ enum ReviewStatus {
}
// Pending human reviews for Human-in-the-loop blocks
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
model PendingHumanReview {
nodeExecId String @id
userId String
@@ -584,6 +590,7 @@ model PendingHumanReview {
reviewedAt DateTime?
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
@@unique([nodeExecId]) // One pending review per node execution
@@ -899,7 +906,8 @@ model StoreListing {
OwningUser User @relation(fields: [owningUserId], references: [id])
// Relations
Versions StoreListingVersion[] @relation("ListingVersions")
waitlistEntries WaitlistEntry[]
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentGraphId])
@@ -1031,6 +1039,47 @@ model StoreListingReview {
@@index([reviewByUserId])
}
enum WaitlistExternalStatus {
DONE
NOT_STARTED
CANCELED
WORK_IN_PROGRESS
}
model WaitlistEntry {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
storeListingId String?
StoreListing StoreListing? @relation(fields: [storeListingId], references: [id], onDelete: SetNull)
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
slug String
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
// Content fields
name String
subHeading String
videoUrl String?
agentOutputDemoUrl String?
imageUrls String[]
description String
categories String[]
//Waitlist specific fields
status WaitlistExternalStatus @default(NOT_STARTED)
votes Int @default(0) // Hide from frontend api
joinedUsers User[] @relation("joinedWaitlists")
// NOTE: DO NOT DOUBLE SEND TO THESE USERS, IF THEY HAVE SIGNED UP SINCE THEY MAY HAVE ALREADY RECEIVED AN EMAIL
// DOUBLE CHECK WHEN SENDING THAT THEY ARE NOT IN THE JOINED USERS LIST ALSO
unaffiliatedEmailUsers String[] @default([])
isDeleted Boolean @default(false)
}
enum SubmissionStatus {
DRAFT // Being prepared, not yet submitted
PENDING // Submitted, awaiting review
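A hedged sketch of how the unaffiliatedEmailUsers note above might be honoured with the Prisma Python client: load the entry together with its joined users and drop any email that already belongs to a joined account. The email field on User and the helper name are assumptions for illustration.

```python
from prisma.models import WaitlistEntry


async def emails_safe_to_notify(waitlist_id: str) -> list[str]:
    # Load the waitlist together with the users who joined through an account.
    entry = await WaitlistEntry.prisma().find_unique(
        where={"id": waitlist_id},
        include={"joinedUsers": True},
    )
    if entry is None:
        return []
    joined_emails = {user.email for user in (entry.joinedUsers or []) if user.email}
    # Keep only email-only signups that have not since joined with an account.
    return [email for email in entry.unaffiliatedEmailUsers if email not in joined_emails]
```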

View File

@@ -34,10 +34,7 @@ logger = logging.getLogger(__name__)
# Default output directory relative to repo root
DEFAULT_OUTPUT_DIR = (
Path(__file__).parent.parent.parent.parent
/ "docs"
/ "integrations"
/ "block-integrations"
Path(__file__).parent.parent.parent.parent / "docs" / "integrations"
)
@@ -369,12 +366,12 @@ def generate_block_markdown(
lines.append("") lines.append("")
# What it is (full description) # What it is (full description)
lines.append("### What it is") lines.append(f"### What it is")
lines.append(block.description or "No description available.") lines.append(block.description or "No description available.")
lines.append("") lines.append("")
# How it works (manual section) # How it works (manual section)
lines.append("### How it works") lines.append(f"### How it works")
how_it_works = manual_content.get( how_it_works = manual_content.get(
"how_it_works", "_Add technical explanation here._" "how_it_works", "_Add technical explanation here._"
) )
@@ -386,7 +383,7 @@ def generate_block_markdown(
# Inputs table (auto-generated)
visible_inputs = [f for f in block.inputs if not f.hidden]
if visible_inputs:
lines.append("### Inputs")
lines.append(f"### Inputs")
lines.append("")
lines.append("| Input | Description | Type | Required |")
lines.append("|-------|-------------|------|----------|")
@@ -403,7 +400,7 @@ def generate_block_markdown(
# Outputs table (auto-generated)
visible_outputs = [f for f in block.outputs if not f.hidden]
if visible_outputs:
lines.append("### Outputs")
lines.append(f"### Outputs")
lines.append("")
lines.append("| Output | Description | Type |")
lines.append("|--------|-------------|------|")
@@ -417,21 +414,13 @@ def generate_block_markdown(
lines.append("") lines.append("")
# Possible use case (manual section) # Possible use case (manual section)
lines.append("### Possible use case") lines.append(f"### Possible use case")
use_case = manual_content.get("use_case", "_Add practical use case examples here._") use_case = manual_content.get("use_case", "_Add practical use case examples here._")
lines.append("<!-- MANUAL: use_case -->") lines.append("<!-- MANUAL: use_case -->")
lines.append(use_case) lines.append(use_case)
lines.append("<!-- END MANUAL -->") lines.append("<!-- END MANUAL -->")
lines.append("") lines.append("")
# Optional per-block extras (only include if has content)
extras = manual_content.get("extras", "")
if extras:
lines.append("<!-- MANUAL: extras -->")
lines.append(extras)
lines.append("<!-- END MANUAL -->")
lines.append("")
lines.append("---") lines.append("---")
lines.append("") lines.append("")
@@ -467,52 +456,25 @@ def get_block_file_mapping(blocks: list[BlockDoc]) -> dict[str, list[BlockDoc]]:
return dict(file_mapping)
def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "") -> str:
"""Generate the overview table markdown (blocks.md).
Args:
blocks: List of block documentation objects
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
"""
def generate_overview_table(blocks: list[BlockDoc]) -> str:
"""Generate the overview table markdown (blocks.md)."""
lines = []
# GitBook YAML frontmatter
lines.append("---")
lines.append("layout:")
lines.append(" width: default")
lines.append(" title:")
lines.append(" visible: true")
lines.append(" description:")
lines.append(" visible: true")
lines.append(" tableOfContents:")
lines.append(" visible: false")
lines.append(" outline:")
lines.append(" visible: true")
lines.append(" pagination:")
lines.append(" visible: true")
lines.append(" metadata:")
lines.append(" visible: true")
lines.append("---")
lines.append("")
lines.append("# AutoGPT Blocks Overview") lines.append("# AutoGPT Blocks Overview")
lines.append("") lines.append("")
lines.append( lines.append(
'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.' 'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.'
) )
lines.append("") lines.append("")
lines.append('{% hint style="info" %}') lines.append('!!! info "Creating Your Own Blocks"')
lines.append("**Creating Your Own Blocks**") lines.append(" Want to create your own custom blocks? Check out our guides:")
lines.append("") lines.append(" ")
lines.append("Want to create your own custom blocks? Check out our guides:")
lines.append("")
lines.append( lines.append(
"* [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples" " - [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
) )
lines.append( lines.append(
"* [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration" " - [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
) )
lines.append("{% endhint %}")
lines.append("") lines.append("")
lines.append( lines.append(
"Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation." "Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation."
@@ -575,8 +537,7 @@ def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "")
else "No description" else "No description"
) )
short_desc = short_desc.replace("\n", " ").replace("|", "\\|") short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
link_path = f"{block_dir_prefix}{file_path}" lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
lines.append("") lines.append("")
continue continue
@@ -602,55 +563,13 @@ def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "")
)
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
link_path = f"{block_dir_prefix}{file_path}"
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
lines.append("")
return "\n".join(lines)
def generate_summary_md(
blocks: list[BlockDoc], root_dir: Path, block_dir_prefix: str = ""
) -> str:
"""Generate SUMMARY.md for GitBook navigation.
Args:
blocks: List of block documentation objects
root_dir: The root docs directory (e.g., docs/integrations/)
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
"""
lines = []
lines.append("# Table of contents")
lines.append("")
lines.append("* [AutoGPT Blocks Overview](README.md)")
lines.append("")
# Check for guides/ directory at the root level (docs/integrations/guides/)
guides_dir = root_dir / "guides"
if guides_dir.exists():
lines.append("## Guides")
lines.append("")
for guide_file in sorted(guides_dir.glob("*.md")):
# Use just the file name for title (replace hyphens/underscores with spaces)
title = file_path_to_title(guide_file.stem.replace("-", "_") + ".md")
lines.append(f"* [{title}](guides/{guide_file.name})")
lines.append("")
lines.append("## Block Integrations")
lines.append("")
file_mapping = get_block_file_mapping(blocks)
for file_path in sorted(file_mapping.keys()):
title = file_path_to_title(file_path)
link_path = f"{block_dir_prefix}{file_path}"
lines.append(f"* [{title}]({link_path})")
lines.append("")
return "\n".join(lines)
def load_all_blocks_for_docs() -> list[BlockDoc]: def load_all_blocks_for_docs() -> list[BlockDoc]:
"""Load all blocks and extract documentation.""" """Load all blocks and extract documentation."""
from backend.blocks import load_all_blocks from backend.blocks import load_all_blocks
@@ -734,16 +653,6 @@ def write_block_docs(
) )
) )
# Add file-level additional_content section if present
file_additional = extract_manual_content(existing_content).get(
"additional_content", ""
)
if file_additional:
content_parts.append("<!-- MANUAL: additional_content -->")
content_parts.append(file_additional)
content_parts.append("<!-- END MANUAL -->")
content_parts.append("")
full_content = file_header + "\n" + "\n".join(content_parts) full_content = file_header + "\n" + "\n".join(content_parts)
generated_files[str(file_path)] = full_content generated_files[str(file_path)] = full_content
@@ -752,28 +661,14 @@ def write_block_docs(
full_path.write_text(full_content) full_path.write_text(full_content)
# Generate overview file at the parent directory (docs/integrations/) # Generate overview file
# with links prefixed to point into block-integrations/ overview_content = generate_overview_table(blocks)
root_dir = output_dir.parent overview_path = output_dir / "README.md"
block_dir_name = output_dir.name # "block-integrations"
block_dir_prefix = f"{block_dir_name}/"
overview_content = generate_overview_table(blocks, block_dir_prefix)
overview_path = root_dir / "README.md"
generated_files["README.md"] = overview_content generated_files["README.md"] = overview_content
overview_path.write_text(overview_content) overview_path.write_text(overview_content)
if verbose: if verbose:
print(" Writing README.md (overview) to parent directory") print(" Writing README.md (overview)")
# Generate SUMMARY.md for GitBook navigation at the parent directory
summary_content = generate_summary_md(blocks, root_dir, block_dir_prefix)
summary_path = root_dir / "SUMMARY.md"
generated_files["SUMMARY.md"] = summary_content
summary_path.write_text(summary_content)
if verbose:
print(" Writing SUMMARY.md (navigation) to parent directory")
return generated_files return generated_files
@@ -853,16 +748,6 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
elif block_match.group(1).strip() != expected_block_content.strip(): elif block_match.group(1).strip() != expected_block_content.strip():
mismatched_blocks.append(block.name) mismatched_blocks.append(block.name)
# Add file-level additional_content to expected content (matches write_block_docs)
file_additional = extract_manual_content(existing_content).get(
"additional_content", ""
)
if file_additional:
content_parts.append("<!-- MANUAL: additional_content -->")
content_parts.append(file_additional)
content_parts.append("<!-- END MANUAL -->")
content_parts.append("")
expected_content = file_header + "\n" + "\n".join(content_parts) expected_content = file_header + "\n" + "\n".join(content_parts)
if existing_content.strip() != expected_content.strip(): if existing_content.strip() != expected_content.strip():
@@ -872,15 +757,11 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
out_of_sync_details.append((file_path, mismatched_blocks)) out_of_sync_details.append((file_path, mismatched_blocks))
all_match = False all_match = False
# Check overview at the parent directory (docs/integrations/) # Check overview
root_dir = output_dir.parent overview_path = output_dir / "README.md"
block_dir_name = output_dir.name # "block-integrations"
block_dir_prefix = f"{block_dir_name}/"
overview_path = root_dir / "README.md"
if overview_path.exists(): if overview_path.exists():
existing_overview = overview_path.read_text() existing_overview = overview_path.read_text()
expected_overview = generate_overview_table(blocks, block_dir_prefix) expected_overview = generate_overview_table(blocks)
if existing_overview.strip() != expected_overview.strip(): if existing_overview.strip() != expected_overview.strip():
print("OUT OF SYNC: README.md (overview)") print("OUT OF SYNC: README.md (overview)")
print(" The blocks overview table needs regeneration") print(" The blocks overview table needs regeneration")
@@ -891,21 +772,6 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
out_of_sync_details.append(("README.md", ["overview table"])) out_of_sync_details.append(("README.md", ["overview table"]))
all_match = False all_match = False
# Check SUMMARY.md at the parent directory
summary_path = root_dir / "SUMMARY.md"
if summary_path.exists():
existing_summary = summary_path.read_text()
expected_summary = generate_summary_md(blocks, root_dir, block_dir_prefix)
if existing_summary.strip() != expected_summary.strip():
print("OUT OF SYNC: SUMMARY.md (navigation)")
print(" The GitBook navigation needs regeneration")
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
all_match = False
else:
print("MISSING: SUMMARY.md (navigation)")
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
all_match = False
# Check for unfilled manual sections # Check for unfilled manual sections
unfilled_patterns = [ unfilled_patterns = [
"_Add a description of this category of blocks._", "_Add a description of this category of blocks._",

View File

@@ -11,7 +11,6 @@
"forked_from_version": null, "forked_from_version": null,
"has_external_trigger": false, "has_external_trigger": false,
"has_human_in_the_loop": false, "has_human_in_the_loop": false,
"has_sensitive_action": false,
"id": "graph-123", "id": "graph-123",
"input_schema": { "input_schema": {
"properties": {}, "properties": {},

View File

@@ -11,7 +11,6 @@
"forked_from_version": null, "forked_from_version": null,
"has_external_trigger": false, "has_external_trigger": false,
"has_human_in_the_loop": false, "has_human_in_the_loop": false,
"has_sensitive_action": false,
"id": "graph-123", "id": "graph-123",
"input_schema": { "input_schema": {
"properties": {}, "properties": {},

View File

@@ -27,8 +27,6 @@
"properties": {} "properties": {}
}, },
"has_external_trigger": false, "has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"trigger_setup_info": null, "trigger_setup_info": null,
"new_output": false, "new_output": false,
"can_access_graph": true, "can_access_graph": true,
@@ -36,8 +34,7 @@
"is_favorite": false, "is_favorite": false,
"recommended_schedule_cron": null, "recommended_schedule_cron": null,
"settings": { "settings": {
"human_in_the_loop_safe_mode": true, "human_in_the_loop_safe_mode": null
"sensitive_action_safe_mode": false
}, },
"marketplace_listing": null "marketplace_listing": null
}, },
@@ -68,8 +65,6 @@
"properties": {} "properties": {}
}, },
"has_external_trigger": false, "has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"trigger_setup_info": null, "trigger_setup_info": null,
"new_output": false, "new_output": false,
"can_access_graph": false, "can_access_graph": false,
@@ -77,8 +72,7 @@
"is_favorite": false, "is_favorite": false,
"recommended_schedule_cron": null, "recommended_schedule_cron": null,
"settings": { "settings": {
"human_in_the_loop_safe_mode": true, "human_in_the_loop_safe_mode": null
"sensitive_action_safe_mode": false
}, },
"marketplace_listing": null "marketplace_listing": null
} }

View File

@@ -1 +0,0 @@
"""Tests for agent generator module."""

View File

@@ -1,273 +0,0 @@
"""
Tests for the Agent Generator core module.
This test suite verifies that the core functions correctly delegate to
the external Agent Generator service.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.api.features.chat.tools.agent_generator import core
from backend.api.features.chat.tools.agent_generator.core import (
AgentGeneratorNotConfiguredError,
)
class TestServiceNotConfigured:
"""Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""
@pytest.mark.asyncio
async def test_decompose_goal_raises_when_not_configured(self):
"""Test that decompose_goal raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.decompose_goal("Build a chatbot")
@pytest.mark.asyncio
async def test_generate_agent_raises_when_not_configured(self):
"""Test that generate_agent raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.generate_agent({"steps": []})
@pytest.mark.asyncio
async def test_generate_agent_patch_raises_when_not_configured(self):
"""Test that generate_agent_patch raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.generate_agent_patch("Add a node", {"nodes": []})
class TestDecomposeGoal:
"""Test decompose_goal function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that decompose_goal calls the external service."""
expected_result = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
result = await core.decompose_goal("Build a chatbot")
mock_external.assert_called_once_with("Build a chatbot", "")
assert result == expected_result
@pytest.mark.asyncio
async def test_passes_context_to_external_service(self):
"""Test that decompose_goal passes context to external service."""
expected_result = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
await core.decompose_goal("Build a chatbot", "Use Python")
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
@pytest.mark.asyncio
async def test_returns_none_on_service_failure(self):
"""Test that decompose_goal returns None when external service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.decompose_goal("Build a chatbot")
assert result is None
class TestGenerateAgent:
"""Test generate_agent function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that generate_agent calls the external service."""
expected_result = {"name": "Test Agent", "nodes": [], "links": []}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions)
# Result should have id, version, is_active added if not present
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
assert result["version"] == 1
assert result["is_active"] is True
@pytest.mark.asyncio
async def test_preserves_existing_id_and_version(self):
"""Test that external service result preserves existing id and version."""
expected_result = {
"id": "existing-id",
"version": 3,
"is_active": False,
"name": "Test Agent",
}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result.copy()
result = await core.generate_agent({"steps": []})
assert result is not None
assert result["id"] == "existing-id"
assert result["version"] == 3
assert result["is_active"] is False
@pytest.mark.asyncio
async def test_returns_none_when_external_service_fails(self):
"""Test that generate_agent returns None when external service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.generate_agent({"steps": []})
assert result is None
class TestGenerateAgentPatch:
"""Test generate_agent_patch function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that generate_agent_patch calls the external service."""
expected_result = {"name": "Updated Agent", "nodes": [], "links": []}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with("Add a node", current_agent)
assert result == expected_result
@pytest.mark.asyncio
async def test_returns_clarifying_questions(self):
"""Test that generate_agent_patch returns clarifying questions."""
expected_result = {
"type": "clarifying_questions",
"questions": [{"question": "What type of node?"}],
}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
result = await core.generate_agent_patch("Add a node", {"nodes": []})
assert result == expected_result
@pytest.mark.asyncio
async def test_returns_none_when_external_service_fails(self):
"""Test that generate_agent_patch returns None when service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.generate_agent_patch("Add a node", {"nodes": []})
assert result is None
class TestJsonToGraph:
"""Test json_to_graph function."""
def test_converts_agent_json_to_graph(self):
"""Test conversion of agent JSON to Graph model."""
agent_json = {
"id": "test-id",
"version": 2,
"is_active": True,
"name": "Test Agent",
"description": "A test agent",
"nodes": [
{
"id": "node1",
"block_id": "block1",
"input_default": {"key": "value"},
"metadata": {"x": 100},
}
],
"links": [
{
"id": "link1",
"source_id": "node1",
"sink_id": "output",
"source_name": "result",
"sink_name": "input",
"is_static": False,
}
],
}
graph = core.json_to_graph(agent_json)
assert graph.id == "test-id"
assert graph.version == 2
assert graph.is_active is True
assert graph.name == "Test Agent"
assert graph.description == "A test agent"
assert len(graph.nodes) == 1
assert graph.nodes[0].id == "node1"
assert graph.nodes[0].block_id == "block1"
assert len(graph.links) == 1
assert graph.links[0].source_id == "node1"
def test_generates_ids_if_missing(self):
"""Test that missing IDs are generated."""
agent_json = {
"name": "Test Agent",
"nodes": [{"block_id": "block1"}],
"links": [],
}
graph = core.json_to_graph(agent_json)
assert graph.id is not None
assert graph.nodes[0].id is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,422 +0,0 @@
"""
Tests for the Agent Generator external service client.
This test suite verifies the external Agent Generator service integration,
including service detection, API calls, and error handling.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from backend.api.features.chat.tools.agent_generator import service
class TestServiceConfiguration:
"""Test service configuration detection."""
def setup_method(self):
"""Reset settings singleton before each test."""
service._settings = None
service._client = None
def test_external_service_not_configured_when_host_empty(self):
"""Test that external service is not configured when host is empty."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = ""
with patch.object(service, "_get_settings", return_value=mock_settings):
assert service.is_external_service_configured() is False
def test_external_service_configured_when_host_set(self):
"""Test that external service is configured when host is set."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = "agent-generator.local"
with patch.object(service, "_get_settings", return_value=mock_settings):
assert service.is_external_service_configured() is True
def test_get_base_url(self):
"""Test base URL construction."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = "agent-generator.local"
mock_settings.config.agentgenerator_port = 8000
with patch.object(service, "_get_settings", return_value=mock_settings):
url = service._get_base_url()
assert url == "http://agent-generator.local:8000"
class TestDecomposeGoalExternal:
"""Test decompose_goal_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_decompose_goal_returns_instructions(self):
"""Test successful decomposition returning instructions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1", "Step 2"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
mock_client.post.assert_called_once_with(
"/api/decompose-description", json={"description": "Build a chatbot"}
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_clarifying_questions(self):
"""Test decomposition returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build something")
assert result == {
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
@pytest.mark.asyncio
async def test_decompose_goal_with_context(self):
"""Test decomposition with additional context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.decompose_goal_external(
"Build a chatbot", context="Use Python"
)
mock_client.post.assert_called_once_with(
"/api/decompose-description",
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_unachievable_goal(self):
"""Test decomposition returning unachievable goal response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Do something impossible")
assert result == {
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
@pytest.mark.asyncio
async def test_decompose_goal_handles_http_error(self):
"""Test decomposition handles HTTP errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.HTTPStatusError(
"Server error", request=MagicMock(), response=MagicMock()
)
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
@pytest.mark.asyncio
async def test_decompose_goal_handles_request_error(self):
"""Test decomposition handles request errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
@pytest.mark.asyncio
async def test_decompose_goal_handles_service_error(self):
"""Test decomposition handles service returning error."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": False,
"error": "Internal error",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
class TestGenerateAgentExternal:
"""Test generate_agent_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_generate_agent_success(self):
"""Test successful agent generation."""
agent_json = {
"name": "Test Agent",
"nodes": [],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": agent_json,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
instructions = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_external(instructions)
assert result == agent_json
mock_client.post.assert_called_once_with(
"/api/generate-agent", json={"instructions": instructions}
)
@pytest.mark.asyncio
async def test_generate_agent_handles_error(self):
"""Test agent generation handles errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_external({"steps": []})
assert result is None
class TestGenerateAgentPatchExternal:
"""Test generate_agent_patch_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_generate_patch_returns_updated_agent(self):
"""Test successful patch generation returning updated agent."""
updated_agent = {
"name": "Updated Agent",
"nodes": [{"id": "1", "block_id": "test"}],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": updated_agent,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_patch_external(
"Add a new node", current_agent
)
assert result == updated_agent
mock_client.post.assert_called_once_with(
"/api/update-agent",
json={
"update_request": "Add a new node",
"current_agent_json": current_agent,
},
)
@pytest.mark.asyncio
async def test_generate_patch_returns_clarifying_questions(self):
"""Test patch generation returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_patch_external(
"Add something", {"nodes": []}
)
assert result == {
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
class TestHealthCheck:
"""Test health_check function."""
def setup_method(self):
"""Reset singletons before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_health_check_returns_false_when_not_configured(self):
"""Test health check returns False when service not configured."""
with patch.object(
service, "is_external_service_configured", return_value=False
):
result = await service.health_check()
assert result is False
@pytest.mark.asyncio
async def test_health_check_returns_true_when_healthy(self):
"""Test health check returns True when service is healthy."""
mock_response = MagicMock()
mock_response.json.return_value = {
"status": "healthy",
"blocks_loaded": True,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is True
mock_client.get.assert_called_once_with("/health")
@pytest.mark.asyncio
async def test_health_check_returns_false_when_not_healthy(self):
"""Test health check returns False when service is not healthy."""
mock_response = MagicMock()
mock_response.json.return_value = {
"status": "unhealthy",
"blocks_loaded": False,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is False
@pytest.mark.asyncio
async def test_health_check_returns_false_on_error(self):
"""Test health check returns False on connection error."""
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is False
class TestGetBlocksExternal:
"""Test get_blocks_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_get_blocks_success(self):
"""Test successful blocks retrieval."""
blocks = [
{"id": "block1", "name": "Block 1"},
{"id": "block2", "name": "Block 2"},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"blocks": blocks,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.get_blocks_external()
assert result == blocks
mock_client.get.assert_called_once_with("/api/blocks")
@pytest.mark.asyncio
async def test_get_blocks_handles_error(self):
"""Test blocks retrieval handles errors gracefully."""
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.get_blocks_external()
assert result is None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -30,7 +30,3 @@ NEXT_PUBLIC_TURNSTILE=disabled
 # PR previews
 NEXT_PUBLIC_PREVIEW_STEALING_DEV=
-
-# PostHog Analytics
-NEXT_PUBLIC_POSTHOG_KEY=
-NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com

View File

@@ -175,8 +175,6 @@ While server components and actions are cool and cutting-edge, they introduce a
 - Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
 - Co-locate UI state inside components/hooks; keep global state minimal
-- Avoid `useMemo` and `useCallback` unless you have a measured performance issue
-- Do not abuse `useEffect`; prefer state colocation and derive values directly when possible
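To make the state-management bullets in this hunk concrete, here is a minimal illustrative sketch (editorial, not part of the diff; the component and prop names are hypothetical) of colocating UI state in the component and deriving a value during render rather than mirroring it into extra state that a `useEffect` would have to keep in sync:

```tsx
import { useState } from "react";

interface Props {
  items: string[];
}

// Hypothetical component: `query` is colocated UI state, and `visibleItems`
// is derived during render instead of being copied into separate state
// by an effect that re-runs whenever `items` or `query` change.
export function ItemFilter({ items }: Props) {
  const [query, setQuery] = useState("");

  const visibleItems = items.filter((item) =>
    item.toLowerCase().includes(query.toLowerCase()),
  );

  return (
    <div>
      <input value={query} onChange={(e) => setQuery(e.target.value)} />
      <ul>
        {visibleItems.map((item) => (
          <li key={item}>{item}</li>
        ))}
      </ul>
    </div>
  );
}
```

Because the filtered list is recomputed from props and local state on each render, there is no synchronization effect to write or to get wrong.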
 ### Styling and components
@@ -551,48 +549,9 @@ Files:
 Types:
 
 - Prefer `interface` for object shapes
-- Component props should be `interface Props { ... }` (not exported)
-- Only use specific exported names (e.g., `export interface MyComponentProps`) when the interface needs to be used outside the component
-- Keep type definitions inline with the component - do not create separate `types.ts` files unless types are shared across multiple files
+- Component props should be `interface Props { ... }`
 - Use precise types; avoid `any` and unsafe casts
 
-**Props naming examples:**
-
-```tsx
-// ✅ Good - internal props, not exported
-interface Props {
-  title: string;
-  onClose: () => void;
-}
-
-export function Modal({ title, onClose }: Props) {
-  // ...
-}
-
-// ✅ Good - exported when needed externally
-export interface ModalProps {
-  title: string;
-  onClose: () => void;
-}
-
-export function Modal({ title, onClose }: ModalProps) {
-  // ...
-}
-
-// ❌ Bad - unnecessarily specific name for internal use
-interface ModalComponentProps {
-  title: string;
-  onClose: () => void;
-}
-
-// ❌ Bad - separate types.ts file for single component
-// types.ts
-export interface ModalProps { ... }
-
-// Modal.tsx
-import type { ModalProps } from './types';
-```
-
 Parameters:
 
 - If more than one parameter is needed, pass a single `Args` object for clarity
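As a quick illustration of the single `Args` object convention kept above (an editorial sketch, not part of this diff; the function and field names are hypothetical):

```tsx
// Hypothetical helper: one `Args` object instead of several positional
// parameters keeps call sites self-documenting and lets optional fields
// be added later without breaking existing callers.
interface Args {
  name: string;
  description?: string;
  isTemplate?: boolean;
}

export function buildGraphPayload({ name, description = "", isTemplate = false }: Args) {
  return { name, description, is_template: isTemplate };
}

// Every argument is labeled at the call site:
console.log(buildGraphPayload({ name: "My agent", isTemplate: true }));
```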

View File

@@ -16,12 +16,6 @@ export default defineConfig({
client: "react-query", client: "react-query",
httpClient: "fetch", httpClient: "fetch",
indexFiles: false, indexFiles: false,
mock: {
type: "msw",
baseUrl: "http://localhost:3000/api/proxy",
generateEachHttpStatus: true,
delay: 0,
},
override: { override: {
mutator: { mutator: {
path: "./mutators/custom-mutator.ts", path: "./mutators/custom-mutator.ts",

View File

@@ -15,8 +15,6 @@
"types": "tsc --noEmit", "types": "tsc --noEmit",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test", "test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui", "test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
"test:unit": "vitest run",
"test:unit:watch": "vitest",
"test:no-build": "playwright test", "test:no-build": "playwright test",
"gentests": "playwright codegen http://localhost:3000", "gentests": "playwright codegen http://localhost:3000",
"storybook": "storybook dev -p 6006", "storybook": "storybook dev -p 6006",
@@ -34,7 +32,6 @@
"@hookform/resolvers": "5.2.2", "@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6", "@next/third-parties": "15.4.6",
"@phosphor-icons/react": "2.1.10", "@phosphor-icons/react": "2.1.10",
"@posthog/react": "1.7.0",
"@radix-ui/react-accordion": "1.2.12", "@radix-ui/react-accordion": "1.2.12",
"@radix-ui/react-alert-dialog": "1.1.15", "@radix-ui/react-alert-dialog": "1.1.15",
"@radix-ui/react-avatar": "1.1.10", "@radix-ui/react-avatar": "1.1.10",
@@ -92,7 +89,6 @@
"next-themes": "0.4.6", "next-themes": "0.4.6",
"nuqs": "2.7.2", "nuqs": "2.7.2",
"party-js": "2.2.0", "party-js": "2.2.0",
"posthog-js": "1.334.1",
"react": "18.3.1", "react": "18.3.1",
"react-currency-input-field": "4.0.3", "react-currency-input-field": "4.0.3",
"react-day-picker": "9.11.1", "react-day-picker": "9.11.1",
@@ -131,8 +127,6 @@
"@storybook/nextjs": "9.1.5", "@storybook/nextjs": "9.1.5",
"@tanstack/eslint-plugin-query": "5.91.2", "@tanstack/eslint-plugin-query": "5.91.2",
"@tanstack/react-query-devtools": "5.90.2", "@tanstack/react-query-devtools": "5.90.2",
"@testing-library/dom": "10.4.1",
"@testing-library/react": "16.3.2",
"@types/canvas-confetti": "1.9.0", "@types/canvas-confetti": "1.9.0",
"@types/lodash": "4.17.20", "@types/lodash": "4.17.20",
"@types/negotiator": "0.6.4", "@types/negotiator": "0.6.4",
@@ -141,7 +135,6 @@
"@types/react-dom": "18.3.5", "@types/react-dom": "18.3.5",
"@types/react-modal": "3.16.3", "@types/react-modal": "3.16.3",
"@types/react-window": "1.8.8", "@types/react-window": "1.8.8",
"@vitejs/plugin-react": "5.1.2",
"axe-playwright": "2.2.2", "axe-playwright": "2.2.2",
"chromatic": "13.3.3", "chromatic": "13.3.3",
"concurrently": "9.2.1", "concurrently": "9.2.1",
@@ -149,7 +142,6 @@
"eslint": "8.57.1", "eslint": "8.57.1",
"eslint-config-next": "15.5.7", "eslint-config-next": "15.5.7",
"eslint-plugin-storybook": "9.1.5", "eslint-plugin-storybook": "9.1.5",
"happy-dom": "20.3.4",
"import-in-the-middle": "2.0.2", "import-in-the-middle": "2.0.2",
"msw": "2.11.6", "msw": "2.11.6",
"msw-storybook-addon": "2.0.6", "msw-storybook-addon": "2.0.6",
@@ -161,9 +153,7 @@
"require-in-the-middle": "8.0.1", "require-in-the-middle": "8.0.1",
"storybook": "9.1.5", "storybook": "9.1.5",
"tailwindcss": "3.4.17", "tailwindcss": "3.4.17",
"typescript": "5.9.3", "typescript": "5.9.3"
"vite-tsconfig-paths": "6.0.4",
"vitest": "4.0.17"
}, },
"msw": { "msw": {
"workerDirectory": [ "workerDirectory": [

File diff suppressed because it is too large.

Binary file not shown (before: 5.9 KiB).

Binary file not shown (before: 19 KiB).

Binary file not shown (before: 26 KiB).

Some files were not shown because too many files have changed in this diff.