Compare commits

..

12 Commits

Author SHA1 Message Date
abhi1992002
cb852a947b refactor(frontend): streamline NodeDataViewer component and execution results handling
### Changes
- Removed unused `NodeExecutionResult` type and `executionResults` prop from `NodeDataViewerProps`.
- Simplified the logic for resolving execution results by directly using the `useNodeStore` hook.
- Updated the component to ensure consistent handling of data types and improved readability.

### Impact
- Enhances code clarity and maintainability by reducing unnecessary complexity in the component.
- Ensures that the latest execution results are effectively utilized in the data viewer.

### Testing
- Verified that the component functions correctly with the updated logic and maintains expected behavior.
2026-01-25 12:25:49 +05:30
Abhimanyu Yadav
10cc347563 Merge branch 'dev' into abhi/show-all-execution-node 2026-01-25 12:17:28 +05:30
abhi1992002
71c0f909f3 refactor(frontend): enhance node execution result handling in nodeStore
### Changes
- Updated the logic for handling duplicate node execution results by using `findIndex` instead of `some`.
- Improved the update mechanism for existing results to ensure that input and output data changes are accurately reflected.
- Recomputed accumulated input and output data when duplicates are found, enhancing data integrity.

### Impact
- Increases the accuracy of node execution data management, ensuring that the latest results are used effectively.
- Enhances code clarity and maintainability by streamlining the update process for node execution results.

### Testing
- Verified that the updated logic correctly handles duplicate results and maintains the integrity of accumulated data.
2026-01-25 12:17:12 +05:30
abhi1992002
9d9fea700b refactor(frontend): update imports and simplify data handling in CustomNode components
### Changes
- Updated import paths for `NodeResolutionData` to use the new `types` module.
- Simplified the data handling logic in `useNodeDataViewer` by removing unnecessary `useMemo` hooks, improving readability and performance.
- Cleaned up imports in `useSubAgentUpdateState` to align with the new structure.

### Impact
- Enhances code clarity and maintainability by reducing complexity in data processing.
- Ensures consistency in import paths across components.

### Testing
- Verified that the functionality of the affected components remains intact after the refactor.
2026-01-25 12:03:22 +05:30
abhi1992002
4e25b1d0b2 refactor(frontend): simplify NodeDataRenderer output handling
### Changes
- Removed unused `useNodeStore` and `useShallow` imports from `NodeOutput.tsx`.
- Simplified the mapping of output data in `NodeDataRenderer` by directly using `latestOutputData` values.
- Updated the `handleCopy` function to use the simplified value variable.

### Impact
- Streamlines the output rendering logic, improving code readability and maintainability.
- Reduces unnecessary dependencies, enhancing performance.

### Testing
- Verified that the output rendering remains consistent with the latest execution results.
2026-01-25 11:54:05 +05:30
abhi1992002
2cd9ec5106 refactor(frontend): update node execution result handling and output rendering
### Changes
- Refactored `AgentOutputs` and `CustomNode` components to handle multiple execution results, ensuring the latest result is used for output rendering.
- Updated `useNodeOutput` to retrieve the latest input and output data from the node store.
- Enhanced `NodeDataRenderer` and `NodeDataViewer` to support grouped execution results, improving the display of output data.
- Introduced new methods in `nodeStore` for managing accumulated input and output data, along with clearing execution results.

### Impact
- Improves the accuracy of displayed outputs by using the most recent execution results.
- Enhances user experience by providing a clearer view of execution data across multiple runs.

### Testing
- Verified that the updated components render the correct output data based on the latest execution results.
- Ensured that all existing tests pass with the new data handling logic.
2026-01-25 11:43:36 +05:30
Zamil Majdy
9a6e17ff52 feat(backend): add external Agent Generator service integration (#11819)
## Summary
- Add support for delegating agent generation to an external
microservice when `AGENTGENERATOR_HOST` is configured
- Falls back to built-in LLM-based implementation when not configured
(default behavior)
- Add comprehensive tests for the service client and core integration
(34 tests)

## Changes
- Add `agentgenerator_host`, `agentgenerator_port`,
`agentgenerator_timeout` settings to `backend/util/settings.py`
- Add `service.py` client for external Agent Generator API endpoints:
  - `/api/decompose-description` - Break down goals into steps
  - `/api/generate-agent` - Generate agent from instructions
  - `/api/update-agent` - Generate patches to update existing agents
  - `/api/blocks` - Get available blocks
  - `/health` - Health check
- Update `core.py` to delegate to external service when configured
- Export `is_external_service_configured` and
`check_external_service_health` from the module

## Related PRs
- Infrastructure repo:
https://github.com/Significant-Gravitas/AutoGPT-cloud-infrastructure/pull/273

## Test plan
- [x] All 34 new tests pass (`poetry run pytest test/agent_generator/
-v`)
- [ ] Deploy with `AGENTGENERATOR_HOST` configured and verify external
service is used
- [ ] Verify built-in implementation still works when
`AGENTGENERATOR_HOST` is empty
2026-01-25 04:08:56 +07:00
Zamil Majdy
fb58827c61 feat(backend;frontend): Implement node-specific auto-approval, safety popup, and race condition fixes (#11810)
## Summary

This PR implements comprehensive improvements to the human-in-the-loop
(HITL) review system, including safety features, architectural changes,
and bug fixes:

### Key Features
- **SECRT-1798: One-time safety popup** - Shows informational popup
before first run of AI-generated agents with sensitive actions/HITL
blocks
- **SECRT-1795: Auto-approval toggle UX** - Toggle in pending reviews
panel to auto-approve future actions from the same node
- **Node-specific auto-approval** - Changed from execution-specific to
node-specific using special key pattern
`auto_approve_{graph_exec_id}_{node_id}`
- **Consolidated approval checking** - Merged `check_auto_approval` into
`check_approval` using single OR query for better performance
- **Race condition prevention** - Added execution status check before
resuming to prevent duplicate execution when approving while graph is
running
- **Parallel auto-approval creation** - Uses `asyncio.gather` for better
performance when creating multiple auto-approval records

## Changes

### Backend Architecture
- **`human_review.py`**: 
- Added `check_approval()` function that checks both normal and
auto-approval in single query
- Added `create_auto_approval_record()` for node-specific auto-approval
using special key pattern
- Added `get_auto_approve_key()` helper to generate consistent
auto-approval keys
- **`review/routes.py`**: 
- Added execution status check before resuming to prevent race
conditions
- Refactored auto-approval record creation to use parallel execution
with `asyncio.gather`
  - Removed obvious comments for cleaner code
- **`review/model.py`**: Added `auto_approve_future_actions` field to
`ReviewRequest`
- **`blocks/helpers/review.py`**: Updated to use consolidated
`check_approval` via database manager client
- **`executor/database.py`**: Exposed `check_approval` through
DatabaseManager RPC for block execution context
- **`data/block.py`**: Fixed safe mode checks for sensitive action
blocks

### Frontend
- **New `AIAgentSafetyPopup`** component with localStorage-based
one-time display
- **`PendingReviewsList`**: 
  - Replaced "Approve all future actions" button with toggle
- Toggle resets data to original values and disables editing when
enabled
  - Shows warning message explaining auto-approval behavior
- **`RunAgentModal`**: Integrated safety popup before first run
- **`usePendingReviews`**: Added polling for real-time badge updates
- **`FloatingSafeModeToggle` & `SafeModeToggle`**: Simplified visibility
logic
- **`local-storage.ts`**: Added localStorage key for popup state
tracking

### Bug Fixes
- Fixed "Client is not connected to query engine" error by using
database manager client pattern
- Fixed race condition where approving reviews while graph is RUNNING
could queue execution twice
- Fixed migration to only drop FK constraint, not non-existent column
- Fixed card data reset when auto-approve toggle changes

### Code Quality
- Removed duplicate/obvious comments
- Moved imports to top-level instead of local scope in tests
- Used walrus operator for cleaner conditional assignments
- Parallel execution for auto-approval record creation

## Test plan
- [ ] Create an AI-generated agent with sensitive actions (e.g., email
sending)
- [ ] First run should show the safety popup before starting
- [ ] Subsequent runs should not show the popup
- [ ] Clear localStorage (`AI_AGENT_SAFETY_POPUP_SHOWN`) to verify popup
shows again
- [ ] Create an agent with human-in-the-loop blocks
- [ ] Run it and verify the pending reviews panel appears
- [ ] Enable the "Auto-approve all future actions" toggle
- [ ] Verify editing is disabled and shows warning message
- [ ] Click "Approve" and verify subsequent blocks from same node
auto-approve
- [ ] Verify auto-approval persists across multiple executions of same
graph
- [ ] Disable toggle and verify editing works normally
- [ ] Verify "Reject" button still works regardless of toggle state
- [ ] Test race condition: Approve reviews while graph is RUNNING
(should skip resume)
- [ ] Test race condition: Approve reviews while graph is REVIEW (should
resume)
- [ ] Verify pending reviews badge updates in real-time when new reviews
are created
2026-01-25 04:05:25 +07:00
Zamil Majdy
595f3508c1 refactor(backend): consolidate embedding error logging to prevent Sentry spam (#11832)
## Summary

Refactors error handling in the embedding service to prevent Sentry
alert spam. Previously, batch operations would log one error per failed
file, causing hundreds of duplicate alerts. Now, exceptions bubble up
from individual functions and are aggregated at the batch level,
producing a single log entry showing all unique error types with counts.

## Changes

### Removed Error Swallowing
- Removed try/except blocks from `generate_embedding()`,
`store_content_embedding()`, `ensure_content_embedding()`,
`get_content_embedding()`, and `ensure_embedding()`
- These functions now raise exceptions instead of returning None/False
on failure
- Added docstring notes: "Raises exceptions on failure - caller should
handle"

### Improved Batch Error Aggregation
- Updated `backfill_all_content_types()` to aggregate unique errors
- Collects all exceptions from batch results
- Groups by error type and message, shows counts
- Single log entry per content type instead of per-file

### Example Output
Before: 50 separate error logs for same issue
After: `BLOCK: 50/100 embeddings failed. Errors: PrismaError: type
vector does not exist (50x)`

## Motivation

This was triggered by the AUTOGPT-SERVER-7D2 Sentry issue where pgvector
errors created hundreds of duplicate alerts. Even after the root cause
was fixed (stale database connections), the error logging pattern would
create spam for any future issues.

## Impact

-  Reduces Sentry noise - single alert per batch instead of per-file
-  Better diagnostics - shows all unique error types with counts
-  Cleaner code - removed ~24 lines of unnecessary error swallowing
-  Proper exception propagation follows Python best practices

## Testing

- Existing tests should pass (error handling moved to batch level)
- Error aggregation logic tested via
asyncio.gather(return_exceptions=True)

## Related Issues

- Fixes Sentry alert spam from AUTOGPT-SERVER-7D2
2026-01-24 21:49:32 +07:00
Ubbe
7892590b12 feat(frontend): refine copilot loading states (#11827)
## Changes 🏗️

- Make the loading UX better when switching between chats or loading a
new chat
- Make session/chat management logic more manageable
- Improving "Deep thinking" loading states
- Fix bug that happened when returning to chat after navigating away

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and test the above
2026-01-23 18:25:45 +07:00
Bently
82d7134fc6 feat(blocks): Add ClaudeCodeBlock for executing tasks via Claude Code in E2B sandbox (#11761)
Introduces a new ClaudeCodeBlock that enables execution of coding tasks
using Anthropic's Claude Code in an E2B sandbox. This block unlocks
powerful agentic coding capabilities - Claude Code can autonomously
create files, install packages, run commands, and build complete
applications within a secure sandboxed environment.

Changes 🏗️

- New file backend/blocks/claude_code.py:
  - ClaudeCodeBlock - Execute tasks using Claude Code in an E2B sandbox
- Dual credential support: E2B API key (sandbox) + Anthropic API key
(Claude Code)
- Session continuation support via session_id, sandbox_id, and
conversation_history
- Automatic file extraction with path, relative_path, name, and content
fields
  - Configurable timeout, setup commands, and working directory
- dispose_sandbox option to keep sandbox alive for multi-turn
conversations

Checklist 📋

For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create and execute ClaudeCodeBlock with a simple prompt ("Create a
hello world HTML file")
- [x] Verify files output includes correct path, relative_path, name,
and content
- [x] Test session continuation by passing session_id and sandbox_id
back
- [x] Build "Any API → Instant App" demo agent combining Firecrawl +
ClaudeCodeBlock + GitHub blocks
- [x] Verify generated files are pushed to GitHub with correct folder
structure using relative_path

Here are two example agents i made that can be used to test this agent,
they require github, anthropic and e2b access via api keys that are set
via the user/on the platform is testing on dev

The first agent is my

Any API → Instant App
"Transform any API documentation into a fully functional web
application. Just provide a docs URL and get a complete, ready-to-deploy
app pushed to a new GitHub repository."

[Any API → Instant
App_v36.json](https://github.com/user-attachments/files/24600326/Any.API.Instant.App_v36.json)


The second agent is my
Idea to project
"Simply enter your coding project's idea and this agent will make all of
the base initial code needed for you to start working on that project
and place it on github for you!"

[Idea to
project_v11.json](https://github.com/user-attachments/files/24600346/Idea.to.project_v11.json)

If you have any questions or issues let me know.

References
https://e2b.dev/blog/python-guide-run-claude-code-in-an-e2b-sandbox

https://github.com/e2b-dev/e2b-cookbook/tree/main/examples/anthropic-claude-code-in-sandbox-python
https://code.claude.com/docs/en/cli-reference

I tried to use E2b's "anthropic-claude-code" template but it kept
complaining it was out of date, so I make it manually spin up a E2b
instance and make it install the latest claude code and it uses that
2026-01-23 10:05:32 +00:00
Nicholas Tindle
90466908a8 refactor(docs): restructure platform docs for GitBook and remove MkDo… (#11825)
<!-- Clearly explain the need for these changes: -->
we met some reality when merging into the docs site but this fixes it
### Changes 🏗️
updates paths, adds some guides
<!-- Concisely describe all of the changes made in this pull request:
-->
update to match reality
### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] deploy it and validate

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Aligns block integrations documentation with GitBook.
> 
> - Changes generator default output to
`docs/integrations/block-integrations` and writes overview `README.md`
and `SUMMARY.md` at `docs/integrations/`
> - Adds GitBook frontmatter and hint syntax to overview; prefixes block
links with `block-integrations/`
> - Introduces `generate_summary_md` to build GitBook navigation
(including optional `guides/`)
> - Preserves per-block manual sections and adds optional `extras` +
file-level `additional_content`
> - Updates sync checker to validate parent `README.md` and `SUMMARY.md`
> - Rewrites `docs/integrations/README.md` with GitBook frontmatter and
updated links; adds `docs/integrations/SUMMARY.md`
> - Adds new guides: `guides/llm-providers.md`,
`guides/voice-providers.md`
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
fdb7ff8111. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: bobby.gaffin <bobby.gaffin@agpt.co>
2026-01-23 06:18:16 +00:00
300 changed files with 7131 additions and 12816 deletions

View File

@@ -122,24 +122,6 @@ class ConnectionManager:
return len(connections)
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
"""Broadcast a message to all active websocket connections."""
message = WSMessage(
method=method,
data=data,
).model_dump_json()
connections = tuple(self.active_connections)
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()

View File

@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
# Return with provider prefix for clarity
return f"{provider_name}: {model_name}"
# Get all models from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
from backend.server.v2.llm import db as llm_db
# Get the recommended model from the database (configurable via admin UI)
recommended_model_slug = await llm_db.get_recommended_model_slug()
# Build the available models list
first_enabled_slug = None
for registry_model in llm_registry.iter_dynamic_models():
# Only include enabled models in the list
if not registry_model.is_enabled:
continue
# Track first enabled model as fallback
if first_enabled_slug is None:
first_enabled_slug = registry_model.slug
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
label = generate_model_label(model_enum)
# Include all LlmModel values (no more filtering by hardcoded list)
recommended_model = LlmModel.GPT4O_MINI.value
for model in LlmModel:
label = generate_model_label(model)
# Add "(Recommended)" suffix to the recommended model
if registry_model.slug == recommended_model_slug:
if model.value == recommended_model:
label += " (Recommended)"
available_models.append(
ModelInfo(
value=registry_model.slug,
value=model.value,
label=label,
provider=registry_model.metadata.provider,
provider=model.provider,
)
)
# Sort models by provider and name for better UX
available_models.sort(key=lambda x: (x.provider, x.label))
# Handle case where no models are available
if not available_models:
logger.warning(
"No enabled LLM models found in registry. "
"Ensure models are configured and enabled in the LLM Registry."
)
# Provide a placeholder entry so admins see meaningful feedback
available_models.append(
ModelInfo(
value="",
label="No models available - configure in LLM Registry",
provider="none",
)
)
# Use the DB recommended model, or fallback to first enabled model
final_recommended = recommended_model_slug or first_enabled_slug or ""
return ExecutionAnalyticsConfig(
available_models=available_models,
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
default_user_prompt=DEFAULT_USER_PROMPT,
recommended_model=final_recommended,
recommended_model=recommended_model,
)

View File

@@ -1,595 +0,0 @@
import logging
import autogpt_libs.auth
import fastapi
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
tags=["llm", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
try:
# Refresh registry from database
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder caches (if they exist)
try:
from backend.api.features.builder import db as builder_db
if hasattr(builder_db, "_get_all_providers"):
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
if hasattr(builder_db, "_build_cached_search_results"):
builder_db._build_cached_search_results.cache_clear()
logger.info("Cleared v2 builder search results cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
except Exception as exc:
logger.exception(
"LLM runtime state refresh failed; caches may be stale: %s", exc
)
@router.get(
"/providers",
summary="List LLM providers",
response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
providers = await llm_db.list_providers(include_models=include_models)
return llm_model.LlmProvidersResponse(providers=providers)
@router.post(
"/providers",
summary="Create LLM provider",
response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
provider = await llm_db.upsert_provider(request=request)
await _refresh_runtime_state()
return provider
@router.patch(
"/providers/{provider_id}",
summary="Update LLM provider",
response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
provider_id: str,
request: llm_model.UpsertLlmProviderRequest,
):
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
await _refresh_runtime_state()
return provider
@router.delete(
"/providers/{provider_id}",
summary="Delete LLM provider",
response_model=dict,
)
async def delete_llm_provider(provider_id: str):
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Delete all models from the provider first before deleting the provider.
"""
try:
await llm_db.delete_provider(provider_id)
await _refresh_runtime_state()
logger.info("Deleted LLM provider '%s'", provider_id)
return {"success": True, "message": "Provider deleted successfully"}
except ValueError as e:
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=500, detail=str(e))
@router.get(
"/models",
summary="List LLM models",
response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
provider_id: str | None = fastapi.Query(default=None),
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
return await llm_db.list_models(
provider_id=provider_id, page=page, page_size=page_size
)
@router.post(
"/models",
summary="Create LLM model",
response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
model = await llm_db.create_model(request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}",
summary="Update LLM model",
response_model=llm_model.LlmModel,
)
async def update_llm_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
):
model = await llm_db.update_model(model_id=model_id, request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}/toggle",
summary="Toggle LLM model availability",
response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
model_id: str,
request: llm_model.ToggleLlmModelRequest,
):
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
If disabling a model and `migrate_to_slug` is provided, all workflows using
this model will be migrated to the specified replacement model before disabling.
A migration record is created which can be reverted later using the revert endpoint.
Optional fields:
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
- `custom_credit_cost`: Custom pricing override for billing during migration
"""
try:
result = await llm_db.toggle_model(
model_id=model_id,
is_enabled=request.is_enabled,
migrate_to_slug=request.migrate_to_slug,
migration_reason=request.migration_reason,
custom_credit_cost=request.custom_credit_cost,
)
await _refresh_runtime_state()
if result.nodes_migrated > 0:
logger.info(
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
result.model.slug,
"enabled" if request.is_enabled else "disabled",
result.nodes_migrated,
result.migrated_to_slug,
result.migration_id,
)
return result
except ValueError as exc:
logger.warning("Model toggle validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to toggle model availability",
) from exc
@router.get(
"/models/{model_id}/usage",
summary="Get model usage count",
response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
"""Get the number of workflow nodes using this model."""
try:
return await llm_db.get_model_usage(model_id=model_id)
except ValueError as exc:
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to get model usage %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get model usage",
) from exc
@router.delete(
"/models/{model_id}",
summary="Delete LLM model and migrate workflows",
response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
model_id: str,
replacement_model_slug: str | None = fastapi.Query(
default=None,
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
),
):
"""
Delete a model and optionally migrate workflows using it to a replacement model.
If no workflows are using this model, it can be deleted without providing a
replacement. If workflows exist, replacement_model_slug is required.
This endpoint:
1. Counts how many workflow nodes use the model being deleted
2. If nodes exist, validates the replacement model and migrates them
3. Deletes the model record
4. Refreshes all caches and notifies executors
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
Example (no usage): DELETE /admin/llm/models/{id}
"""
try:
result = await llm_db.delete_model(
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug,
)
return result
except ValueError as exc:
# Validation errors (model not found, replacement invalid, etc.)
logger.warning("Model deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
# ============================================================================
# Migration Management Endpoints
# ============================================================================
@router.get(
"/migrations",
summary="List model migrations",
response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
include_reverted: bool = fastapi.Query(
default=False, description="Include reverted migrations in the list"
),
):
"""
List all model migrations.
Migrations are created when disabling a model with the migrate_to_slug option.
They can be reverted to restore the original model configuration.
"""
try:
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
return llm_model.LlmMigrationsResponse(migrations=migrations)
except Exception as exc:
logger.exception("Failed to list migrations: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list migrations",
) from exc
@router.get(
"/migrations/{migration_id}",
summary="Get migration details",
response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
"""Get details of a specific migration."""
try:
migration = await llm_db.get_migration(migration_id)
if not migration:
raise fastapi.HTTPException(
status_code=404, detail=f"Migration '{migration_id}' not found"
)
return migration
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get migration",
) from exc
@router.post(
"/migrations/{migration_id}/revert",
summary="Revert a model migration",
response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
migration_id: str,
request: llm_model.RevertMigrationRequest | None = None,
):
"""
Revert a model migration, restoring affected workflows to their original model.
This only reverts the specific nodes that were part of the migration.
The source model must exist for the revert to succeed.
Options:
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
Response includes:
- `nodes_reverted`: Number of nodes successfully reverted
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
- `source_model_re_enabled`: Whether the source model was re-enabled
Requirements:
- Migration must not already be reverted
- Source model must exist
"""
try:
re_enable = request.re_enable_source_model if request else True
result = await llm_db.revert_migration(
migration_id,
re_enable_source_model=re_enable,
)
await _refresh_runtime_state()
logger.info(
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
"(%d already changed, source re-enabled=%s)",
migration_id,
result.nodes_reverted,
result.target_model_slug,
result.source_model_slug,
result.nodes_already_changed,
result.source_model_re_enabled,
)
return result
except ValueError as exc:
logger.warning("Migration revert validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to revert migration",
) from exc
# ============================================================================
# Creator Management Endpoints
# ============================================================================
@router.get(
"/creators",
summary="List model creators",
response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
"""
List all model creators.
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
This is distinct from providers who host/serve the models (e.g., OpenRouter).
"""
try:
creators = await llm_db.list_creators()
return llm_model.LlmCreatorsResponse(creators=creators)
except Exception as exc:
logger.exception("Failed to list creators: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list creators",
) from exc
@router.get(
"/creators/{creator_id}",
summary="Get creator details",
response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
"""Get details of a specific model creator."""
try:
creator = await llm_db.get_creator(creator_id)
if not creator:
raise fastapi.HTTPException(
status_code=404, detail=f"Creator '{creator_id}' not found"
)
return creator
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get creator",
) from exc
@router.post(
"/creators",
summary="Create model creator",
response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
"""
Create a new model creator.
A creator represents an organization that creates/trains AI models,
such as OpenAI, Anthropic, Meta, or Google.
"""
try:
creator = await llm_db.upsert_creator(request=request)
await _refresh_runtime_state()
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
return creator
except Exception as exc:
logger.exception("Failed to create creator: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to create creator",
) from exc
@router.patch(
"/creators/{creator_id}",
summary="Update model creator",
response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
creator_id: str,
request: llm_model.UpsertLlmCreatorRequest,
):
"""Update an existing model creator."""
try:
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
await _refresh_runtime_state()
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
return creator
except Exception as exc:
logger.exception("Failed to update creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to update creator",
) from exc
@router.delete(
"/creators/{creator_id}",
summary="Delete model creator",
response_model=dict,
)
async def delete_llm_creator(creator_id: str):
"""
Delete a model creator.
This will remove the creator association from all models that reference it
(sets creatorId to NULL), but will not delete the models themselves.
"""
try:
await llm_db.delete_creator(creator_id)
await _refresh_runtime_state()
logger.info("Deleted model creator '%s'", creator_id)
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
except ValueError as exc:
logger.warning("Creator deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete creator",
) from exc
# ============================================================================
# Recommended Model Endpoints
# ============================================================================
@router.get(
"/recommended-model",
summary="Get recommended model",
response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
"""
Get the currently recommended LLM model.
The recommended model is shown to users as the default/suggested option
in model selection dropdowns.
"""
try:
model = await llm_db.get_recommended_model()
return llm_model.RecommendedModelResponse(
model=model,
slug=model.slug if model else None,
)
except Exception as exc:
logger.exception("Failed to get recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get recommended model",
) from exc
@router.post(
"/recommended-model",
summary="Set recommended model",
response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
"""
Set a model as the recommended model.
This clears the recommended flag from any other model and sets it on
the specified model. The model must be enabled to be set as recommended.
The recommended model is displayed to users as the default/suggested
option in model selection dropdowns throughout the platform.
"""
try:
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
await _refresh_runtime_state()
logger.info(
"Set recommended model to '%s' (previous: %s)",
model.slug,
previous_slug or "none",
)
return llm_model.SetRecommendedModelResponse(
model=model,
previous_recommended_slug=previous_slug,
message=f"Model '{model.display_name}' is now the recommended model",
)
except ValueError as exc:
logger.warning("Set recommended model validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to set recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to set recommended model",
) from exc

View File

@@ -1,491 +0,0 @@
import json
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
import backend.api.features.admin.llm_routes as llm_routes
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_list_llm_providers_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM providers"""
# Mock the database function
mock_providers = [
{
"id": "provider-1",
"name": "openai",
"display_name": "OpenAI",
"description": "OpenAI LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
{
"id": "provider-2",
"name": "anthropic",
"display_name": "Anthropic",
"description": "Anthropic LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
response = client.get("/admin/llm/providers")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["providers"]) == 2
assert response_data["providers"][0]["name"] == "openai"
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_providers_success.json",
)
def test_list_llm_models_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM models with pagination"""
# Mock the database function - now returns LlmModelsResponse
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=True,
capabilities={},
metadata={},
costs=[
llm_model.LlmModelCost(
id="cost-1",
credit_cost=10,
credential_provider="openai",
metadata={},
)
],
)
mock_response = llm_model.LlmModelsResponse(
models=[mock_model],
pagination=Pagination(
total_items=1,
total_pages=1,
current_page=1,
page_size=50,
),
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_response),
)
response = client.get("/admin/llm/models")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["models"]) == 1
assert response_data["models"][0]["slug"] == "gpt-4o"
assert response_data["pagination"]["total_items"] == 1
assert response_data["pagination"]["page_size"] == 50
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_models_success.json",
)
def test_create_llm_provider_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM provider"""
mock_provider = {
"id": "new-provider-id",
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
response = client.post("/admin/llm/providers", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_provider_success.json",
)
def test_create_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM model"""
mock_model = {
"id": "new-model-id",
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-id",
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
response = client.post("/admin/llm/models", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_model_success.json",
)
def test_update_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful update of LLM model"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o Updated",
"description": "Updated description",
"provider_id": "provider-1",
"context_window": 256000,
"max_output_tokens": 32768,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 15,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"display_name": "GPT-4o Updated",
"description": "Updated description",
"context_window": 256000,
"max_output_tokens": 32768,
}
response = client.patch("/admin/llm/models/model-1", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"update_llm_model_success.json",
)
def test_toggle_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful toggling of LLM model enabled status"""
# Create a proper mock model object
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=False,
capabilities={},
metadata={},
costs=[],
)
# Create a proper ToggleLlmModelResponse
mock_response = llm_model.ToggleLlmModelResponse(
model=mock_model,
nodes_migrated=0,
migrated_to_slug=None,
migration_id=None,
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {"is_enabled": False}
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["model"]["is_enabled"] is False
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"toggle_llm_model_success.json",
)
def test_delete_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful deletion of LLM model with migration"""
# Create a proper DeleteLlmModelResponse
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="gpt-3.5-turbo",
deleted_model_display_name="GPT-3.5 Turbo",
replacement_model_slug="gpt-4o-mini",
nodes_migrated=42,
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
)
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
assert response_data["nodes_migrated"] == 42
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"delete_llm_model_success.json",
)
def test_delete_llm_model_validation_error(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails with proper error when validation fails"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]
def test_delete_llm_model_no_replacement_with_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails when nodes exist but no replacement is provided"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(
side_effect=ValueError(
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
"Please provide a replacement_model_slug to migrate them."
)
),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 400
assert "workflow node(s) are using it" in response.json()["detail"]
def test_delete_llm_model_no_replacement_no_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="unused-model",
deleted_model_display_name="Unused Model",
replacement_model_slug=None,
nodes_migrated=0,
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "unused-model"
assert response_data["nodes_migrated"] == 0
assert response_data["replacement_model_slug"] is None
mock_refresh.assert_called_once()

View File

@@ -15,7 +15,6 @@ from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.data.llm_registry import get_all_model_slugs_for_validation
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -32,14 +31,7 @@ from .model import (
)
logger = logging.getLogger(__name__)
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -504,8 +496,8 @@ async def _get_static_counts():
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models from registry
if any(query in name for name in _get_llm_models()):
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False

View File

@@ -1,29 +1,28 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
apply_agent_patch,
AgentGeneratorNotConfiguredError,
decompose_goal,
generate_agent,
generate_agent_patch,
get_agent_as_json,
json_to_graph,
save_agent_to_library,
)
from .fixer import apply_all_fixes
from .utils import get_blocks_info
from .validator import validate_agent
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
# Core functions
"decompose_goal",
"generate_agent",
"generate_agent_patch",
"apply_agent_patch",
"save_agent_to_library",
"get_agent_as_json",
# Fixer
"apply_all_fixes",
# Validator
"validate_agent",
# Utils
"get_blocks_info",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"is_external_service_configured",
"check_external_service_health",
]

View File

@@ -1,25 +0,0 @@
"""OpenRouter client configuration for agent generation."""
import os
from openai import AsyncOpenAI
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
# OpenRouter client (OpenAI-compatible API)
_client: AsyncOpenAI | None = None
def get_client() -> AsyncOpenAI:
"""Get or create the OpenRouter client."""
global _client
if _client is None:
if not OPENROUTER_API_KEY:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
_client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=OPENROUTER_API_KEY,
)
return _client

View File

@@ -1,7 +1,5 @@
"""Core agent generation functions."""
import copy
import json
import logging
import uuid
from typing import Any
@@ -9,13 +7,35 @@ from typing import Any
from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph
from .client import AGENT_GENERATOR_MODEL, get_client
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
from .utils import get_block_summaries, parse_json_from_llm
from .service import (
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
logger = logging.getLogger(__name__)
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
client = get_client()
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
full_description = description
if context:
full_description = f"{description}\n\nAdditional context:\n{context}"
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": full_description},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for decomposition")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse decomposition response: {content[:200]}")
return None
return result
except Exception as e:
logger.error(f"Error decomposing goal: {e}")
return None
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
return await decompose_goal_external(description, context)
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
Returns:
Agent JSON dict or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
client = get_client()
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": json.dumps(instructions, indent=2)},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for agent generation")
return None
result = parse_json_from_llm(content)
if result is None:
logger.error(f"Failed to parse agent JSON: {content[:200]}")
return None
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions)
if result:
# Ensure required fields
if "id" not in result:
result["id"] = str(uuid.uuid4())
@@ -104,12 +80,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
except Exception as e:
logger.error(f"Error generating agent: {e}")
return None
return result
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
@@ -284,108 +255,23 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Generate a patch to update an existing agent.
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Patch dict or clarifying questions, or None on error
Updated agent JSON, clarifying questions dict, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
client = get_client()
prompt = PATCH_PROMPT.format(
current_agent=json.dumps(current_agent, indent=2),
block_summaries=get_block_summaries(),
)
try:
response = await client.chat.completions.create(
model=AGENT_GENERATOR_MODEL,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": update_request},
],
temperature=0,
)
content = response.choices[0].message.content
if content is None:
logger.error("LLM returned empty content for patch generation")
return None
return parse_json_from_llm(content)
except Exception as e:
logger.error(f"Error generating patch: {e}")
return None
def apply_agent_patch(
current_agent: dict[str, Any], patch: dict[str, Any]
) -> dict[str, Any]:
"""Apply a patch to an existing agent.
Args:
current_agent: Current agent JSON
patch: Patch dict with operations
Returns:
Updated agent JSON
"""
agent = copy.deepcopy(current_agent)
patches = patch.get("patches", [])
for p in patches:
patch_type = p.get("type")
if patch_type == "modify":
node_id = p.get("node_id")
changes = p.get("changes", {})
for node in agent.get("nodes", []):
if node["id"] == node_id:
_deep_update(node, changes)
logger.debug(f"Modified node {node_id}")
break
elif patch_type == "add":
new_nodes = p.get("new_nodes", [])
new_links = p.get("new_links", [])
agent["nodes"] = agent.get("nodes", []) + new_nodes
agent["links"] = agent.get("links", []) + new_links
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
elif patch_type == "remove":
node_ids_to_remove = set(p.get("node_ids", []))
link_ids_to_remove = set(p.get("link_ids", []))
# Remove nodes
agent["nodes"] = [
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
]
# Remove links (both explicit and those referencing removed nodes)
agent["links"] = [
link
for link in agent.get("links", [])
if link["id"] not in link_ids_to_remove
and link["source_id"] not in node_ids_to_remove
and link["sink_id"] not in node_ids_to_remove
]
logger.debug(
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
)
return agent
def _deep_update(target: dict, source: dict) -> None:
"""Recursively update a dict with another dict."""
for key, value in source.items():
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
_deep_update(target[key], value)
else:
target[key] = value
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)

View File

@@ -1,606 +0,0 @@
"""Agent fixer - Fixes common LLM generation errors."""
import logging
import re
import uuid
from typing import Any
from .utils import (
ADDTODICTIONARY_BLOCK_ID,
ADDTOLIST_BLOCK_ID,
CODE_EXECUTION_BLOCK_ID,
CONDITION_BLOCK_ID,
CREATEDICT_BLOCK_ID,
CREATELIST_BLOCK_ID,
DATA_SAMPLING_BLOCK_ID,
DOUBLE_CURLY_BRACES_BLOCK_IDS,
GET_CURRENT_DATE_BLOCK_ID,
STORE_VALUE_BLOCK_ID,
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
get_blocks_info,
is_valid_uuid,
)
logger = logging.getLogger(__name__)
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix invalid UUIDs in agent and link IDs."""
# Fix agent ID
if not is_valid_uuid(agent.get("id", "")):
agent["id"] = str(uuid.uuid4())
logger.debug(f"Fixed agent ID: {agent['id']}")
# Fix node IDs
id_mapping = {} # Old ID -> New ID
for node in agent.get("nodes", []):
if not is_valid_uuid(node.get("id", "")):
old_id = node.get("id", "")
new_id = str(uuid.uuid4())
id_mapping[old_id] = new_id
node["id"] = new_id
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
# Fix link IDs and update references
for link in agent.get("links", []):
if not is_valid_uuid(link.get("id", "")):
link["id"] = str(uuid.uuid4())
logger.debug(f"Fixed link ID: {link['id']}")
# Update source/sink IDs if they were remapped
if link.get("source_id") in id_mapping:
link["source_id"] = id_mapping[link["source_id"]]
if link.get("sink_id") in id_mapping:
link["sink_id"] = id_mapping[link["sink_id"]]
return agent
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix single curly braces to double in template blocks."""
for node in agent.get("nodes", []):
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
continue
input_data = node.get("input_default", {})
for key in ("prompt", "format"):
if key in input_data and isinstance(input_data[key], str):
original = input_data[key]
# Fix simple variable references: {var} -> {{var}}
fixed = re.sub(
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
r"{{\1}}",
original,
)
if fixed != original:
input_data[key] = fixed
logger.debug(f"Fixed curly braces in {key}")
return agent
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Find all ConditionBlock nodes
condition_node_ids = {
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
}
if not condition_node_ids:
return agent
new_nodes = []
new_links = []
processed_conditions = set()
for link in links:
sink_id = link.get("sink_id")
sink_name = link.get("sink_name")
# Check if this link goes to a ConditionBlock's value2
if sink_id in condition_node_ids and sink_name == "value2":
source_node = next(
(n for n in nodes if n["id"] == link.get("source_id")), None
)
# Skip if source is already a StoreValueBlock
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
continue
# Skip if we already processed this condition
if sink_id in processed_conditions:
continue
processed_conditions.add(sink_id)
# Create StoreValueBlock
store_node_id = str(uuid.uuid4())
store_node = {
"id": store_node_id,
"block_id": STORE_VALUE_BLOCK_ID,
"input_default": {"data": None},
"metadata": {"position": {"x": 0, "y": -100}},
}
new_nodes.append(store_node)
# Create link: original source -> StoreValueBlock
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": store_node_id,
"sink_name": "input",
"is_static": False,
}
)
# Update original link: StoreValueBlock -> ConditionBlock
link["source_id"] = store_node_id
link["source_name"] = "output"
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
if new_nodes:
agent["nodes"] = nodes + new_nodes
return agent
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
When an AddToList block is found:
1. Checks if there's a CreateListBlock before it
2. Removes CreateListBlock if linked directly to AddToList
3. Adds an empty AddToList block before the original
4. Ensures the original has a self-referencing link
"""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
new_nodes = []
original_addtolist_ids = set()
nodes_to_remove = set()
links_to_remove = []
# First pass: identify CreateListBlock nodes to remove
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATELIST_BLOCK_ID
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
# Second pass: process AddToList blocks
filtered_nodes = []
for node in nodes:
if node.get("id") in nodes_to_remove:
continue
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
original_addtolist_ids.add(node.get("id"))
node_id = node.get("id")
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
# Check if already has prerequisite
has_prereq = any(
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_name") == "updated_list"
for link in links
)
if not has_prereq:
# Remove links to "list" input (except self-reference)
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_id") != node_id
and link not in links_to_remove
):
links_to_remove.append(link)
# Create prerequisite AddToList block
prereq_id = str(uuid.uuid4())
prereq_node = {
"id": prereq_id,
"block_id": ADDTOLIST_BLOCK_ID,
"input_default": {"list": [], "entry": None, "entries": []},
"metadata": {
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
},
}
new_nodes.append(prereq_node)
# Link prerequisite to original
links.append(
{
"id": str(uuid.uuid4()),
"source_id": prereq_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added prerequisite AddToList block for {node_id}")
filtered_nodes.append(node)
# Remove marked links
filtered_links = [link for link in links if link not in links_to_remove]
# Add self-referencing links for original AddToList blocks
for node in filtered_nodes + new_nodes:
if (
node.get("block_id") == ADDTOLIST_BLOCK_ID
and node.get("id") in original_addtolist_ids
):
node_id = node.get("id")
has_self_ref = any(
link["source_id"] == node_id
and link["sink_id"] == node_id
and link["source_name"] == "updated_list"
and link["sink_name"] == "list"
for link in filtered_links
)
if not has_self_ref:
filtered_links.append(
{
"id": str(uuid.uuid4()),
"source_id": node_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added self-reference for AddToList {node_id}")
agent["nodes"] = filtered_nodes + new_nodes
agent["links"] = filtered_links
return agent
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
nodes_to_remove = set()
links_to_remove = []
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
if (
source_node
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
and link.get("source_name") == "response"
):
link["source_name"] = "stdout_logs"
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
return agent
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
links_to_remove = []
for node in nodes:
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Remove links to sample_size
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "sample_size"
):
links_to_remove.append(link)
# Set default
input_default["sample_size"] = 1
node["input_default"] = input_default
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
if links_to_remove:
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
node_lookup = {n.get("id"): n for n in nodes}
for link in links:
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_node = node_lookup.get(source_id)
sink_node = node_lookup.get(sink_id)
if not source_node or not sink_node:
continue
source_pos = source_node.get("metadata", {}).get("position", {})
sink_pos = sink_node.get("metadata", {}).get("position", {})
source_x = source_pos.get("x", 0)
sink_x = sink_pos.get("x", 0)
if abs(sink_x - source_x) < 800:
new_x = source_x + 800
if "metadata" not in sink_node:
sink_node["metadata"] = {}
if "position" not in sink_node["metadata"]:
sink_node["metadata"]["position"] = {}
sink_node["metadata"]["position"]["x"] = new_x
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
return agent
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
for node in agent.get("nodes", []):
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
input_default = node.get("input_default", {})
if "offset" in input_default:
offset = input_default["offset"]
if isinstance(offset, (int, float)) and offset < 0:
input_default["offset"] = abs(offset)
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
return agent
def fix_ai_model_parameter(
agent: dict[str, Any],
blocks_info: list[dict[str, Any]],
default_model: str = "gpt-4o",
) -> dict[str, Any]:
"""Add default model parameter to AI blocks if missing."""
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
# Check if block has AI category
categories = block.get("categories", [])
is_ai_block = any(
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
)
if is_ai_block:
input_default = node.get("input_default", {})
if "model" not in input_default:
input_default["model"] = default_model
node["input_default"] = input_default
logger.debug(
f"Added model '{default_model}' to AI block {node.get('id')}"
)
return agent
def fix_link_static_properties(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix is_static property based on source block's staticOutput."""
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
if not source_node:
continue
source_block = block_map.get(source_node.get("block_id"))
if not source_block:
continue
static_output = source_block.get("staticOutput", False)
if link.get("is_static") != static_output:
link["is_static"] = static_output
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
return agent
def fix_data_type_mismatch(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in nodes}
def get_property_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
type_mapping = {
"string": "string",
"text": "string",
"integer": "number",
"number": "number",
"float": "number",
"boolean": "boolean",
"bool": "boolean",
"array": "list",
"list": "list",
"object": "dictionary",
"dict": "dictionary",
"dictionary": "dictionary",
}
new_links = []
nodes_to_add = []
for link in links:
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
new_links.append(link)
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
new_links.append(link)
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_property_type(source_outputs, link.get("source_name", ""))
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
# Insert type converter
converter_id = str(uuid.uuid4())
target_type = type_mapping.get(sink_type, sink_type)
converter_node = {
"id": converter_id,
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
"input_default": {"type": target_type},
"metadata": {"position": {"x": 0, "y": 100}},
}
nodes_to_add.append(converter_node)
# source -> converter
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": converter_id,
"sink_name": "value",
"is_static": False,
}
)
# converter -> sink
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": converter_id,
"source_name": "value",
"sink_id": link["sink_id"],
"sink_name": link["sink_name"],
"is_static": False,
}
)
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
else:
new_links.append(link)
if nodes_to_add:
agent["nodes"] = nodes + nodes_to_add
agent["links"] = new_links
return agent
def apply_all_fixes(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""Apply all fixes to an agent JSON.
Args:
agent: Agent JSON dict
blocks_info: Optional list of block info dicts for advanced fixes
Returns:
Fixed agent JSON
"""
# Basic fixes (no block info needed)
agent = fix_agent_ids(agent)
agent = fix_double_curly_braces(agent)
agent = fix_storevalue_before_condition(agent)
agent = fix_addtolist_blocks(agent)
agent = fix_addtodictionary_blocks(agent)
agent = fix_code_execution_output(agent)
agent = fix_data_sampling_sample_size(agent)
agent = fix_node_x_coordinates(agent)
agent = fix_getcurrentdate_offset(agent)
# Advanced fixes (require block info)
if blocks_info is None:
blocks_info = get_blocks_info()
agent = fix_ai_model_parameter(agent, blocks_info)
agent = fix_link_static_properties(agent, blocks_info)
agent = fix_data_type_mismatch(agent, blocks_info)
return agent

View File

@@ -1,225 +0,0 @@
"""Prompt templates for agent generation."""
DECOMPOSITION_PROMPT = """
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
---
FIRST: Analyze the user's goal and determine:
1) Design-time configuration (fixed settings that won't change per run)
2) Runtime inputs (values the agent's end-user will provide each time it runs)
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
- DO NOT ask for the actual value
- Instead, define it as an Agent Input with a clear name, type, and description
Only ask clarifying questions about design-time config that affects how you build the workflow:
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
- Business rules that must be hard-coded
IMPORTANT CLARIFICATIONS POLICY:
- Ask no more than five essential questions
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
- Do not ask for API keys or credentials; the platform handles those directly
- If there is enough information to infer reasonable defaults, prefer to propose defaults
---
GUIDELINES:
1. List each step as a numbered item
2. Describe the action clearly and specify inputs/outputs
3. Ensure steps are in logical, sequential order
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
5. Help the user reach their goal efficiently
---
RULES:
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
2. USE ONLY THE BLOCKS PROVIDED
3. ALL required_input fields must be provided
4. Data types of linked properties must match
5. Write expert-level prompts for AI-related blocks
---
CRITICAL BLOCK RESTRICTIONS:
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
3. ConditionBlock: value2 is reference, value1 is contrast
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
---
OUTPUT FORMAT:
If more information is needed:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
"keyword": "email_provider",
"example": "Gmail"
}}
]
}}
```
If ready to proceed:
```json
{{
"type": "instructions",
"steps": [
{{
"step_number": 1,
"block_name": "AgentShortTextInputBlock",
"description": "Get the URL of the content to analyze.",
"inputs": [{{"name": "name", "value": "URL"}}],
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
}}
]
}}
```
---
AVAILABLE BLOCKS:
{block_summaries}
"""
GENERATION_PROMPT = """
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
---
NODES:
Each node must include:
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
- `block_id`: The block identifier (must match an Allowed Block)
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
- `metadata`: Must contain:
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
- `customized_name`: Clear name describing this block's purpose in the workflow
---
LINKS:
Each link connects a source node's output to a sink node's input:
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
- `source_id`: ID of the source node
- `source_name`: Output field name from the source block
- `sink_id`: ID of the sink node
- `sink_name`: Input field name on the sink block
- `is_static`: true only if source block has static_output: true
CRITICAL: All IDs must be valid UUID v4 format!
---
AGENT (GRAPH):
Wrap nodes and links in:
- `id`: UUID of the agent
- `name`: Short, generic name (avoid specific company names, URLs)
- `description`: Short, generic description
- `nodes`: List of all nodes
- `links`: List of all links
- `version`: 1
- `is_active`: true
---
TIPS:
- All required_input fields must be provided via input_default or a valid link
- Ensure consistent source_id and sink_id references
- Avoid dangling links
- Input/output pins must match block schemas
- Do not invent unknown block_ids
---
ALLOWED BLOCKS:
{block_summaries}
---
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
"""
PATCH_PROMPT = """
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
CURRENT AGENT:
{current_agent}
AVAILABLE BLOCKS:
{block_summaries}
---
PATCH FORMAT:
Return a JSON object with the following structure:
```json
{{
"type": "patch",
"intent": "Brief description of what the patch does",
"patches": [
{{
"type": "modify",
"node_id": "uuid-of-node-to-modify",
"changes": {{
"input_default": {{"field": "new_value"}},
"metadata": {{"customized_name": "New Name"}}
}}
}},
{{
"type": "add",
"new_nodes": [
{{
"id": "new-uuid",
"block_id": "block-uuid",
"input_default": {{}},
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
}}
],
"new_links": [
{{
"id": "link-uuid",
"source_id": "source-node-id",
"source_name": "output_field",
"sink_id": "sink-node-id",
"sink_name": "input_field"
}}
]
}},
{{
"type": "remove",
"node_ids": ["uuid-of-node-to-remove"],
"link_ids": ["uuid-of-link-to-remove"]
}}
]
}}
```
If you need more information, return:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "What specific change do you want?",
"keyword": "change_type",
"example": "Add error handling"
}}
]
}}
```
Generate the minimal patch needed. Output ONLY valid JSON.
"""

View File

@@ -0,0 +1,269 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
Or None on error
"""
client = _get_client()
# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return None
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def generate_agent_external(
instructions: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
Returns:
Agent JSON dict or None on error
"""
client = _get_client()
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Updated agent JSON, clarifying questions dict, or None on error
"""
client = _get_client()
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -1,213 +0,0 @@
"""Utilities for agent generation."""
import json
import re
from typing import Any
from backend.data.block import get_blocks
# UUID validation regex
UUID_REGEX = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
)
# Block IDs for various fixes
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
"3b191d9f-356f-482d-8238-ba04b6d18381",
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
"716a67b3-6760-42e7-86dc-18645c6e00fc",
"530cf046-2ce0-4854-ae2c-659db17c7a46",
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
]
def is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID v4."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def _compact_schema(schema: dict) -> dict[str, str]:
"""Extract compact type info from a JSON schema properties dict.
Returns a dict of {field_name: type_string} for essential info only.
"""
props = schema.get("properties", {})
result = {}
for name, prop in props.items():
# Skip internal/complex fields
if name.startswith("_"):
continue
# Get type string
type_str = prop.get("type", "any")
# Handle anyOf/oneOf (optional types)
if "anyOf" in prop:
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
type_str = "|".join(types) if types else "any"
elif "allOf" in prop:
type_str = "object"
# Add array item type if present
if type_str == "array" and "items" in prop:
items = prop["items"]
if isinstance(items, dict):
item_type = items.get("type", "any")
type_str = f"array[{item_type}]"
result[name] = type_str
return result
def get_block_summaries(include_schemas: bool = True) -> str:
"""Generate compact block summaries for prompts.
Args:
include_schemas: Whether to include input/output type info
Returns:
Formatted string of block summaries (compact format)
"""
blocks = get_blocks()
summaries = []
for block_id, block_cls in blocks.items():
block = block_cls()
name = block.name
desc = getattr(block, "description", "") or ""
# Truncate description
if len(desc) > 150:
desc = desc[:147] + "..."
if not include_schemas:
summaries.append(f"- {name} (id: {block_id}): {desc}")
else:
# Compact format with type info only
inputs = {}
outputs = {}
required = []
if hasattr(block, "input_schema"):
try:
schema = block.input_schema.jsonschema()
inputs = _compact_schema(schema)
required = schema.get("required", [])
except Exception:
pass
if hasattr(block, "output_schema"):
try:
schema = block.output_schema.jsonschema()
outputs = _compact_schema(schema)
except Exception:
pass
# Build compact line format
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
req_str = f" req=[{','.join(required)}]" if required else ""
static = " [static]" if getattr(block, "static_output", False) else ""
line = f"- {name} (id: {block_id}): {desc}"
if in_str:
line += f"\n in: {{{in_str}}}{req_str}"
if out_str:
line += f"\n out: {{{out_str}}}{static}"
summaries.append(line)
return "\n".join(summaries)
def get_blocks_info() -> list[dict[str, Any]]:
"""Get block information with schemas for validation and fixing."""
blocks = get_blocks()
blocks_info = []
for block_id, block_cls in blocks.items():
block = block_cls()
blocks_info.append(
{
"id": block_id,
"name": block.name,
"description": getattr(block, "description", ""),
"categories": getattr(block, "categories", []),
"staticOutput": getattr(block, "static_output", False),
"inputSchema": (
block.input_schema.jsonschema()
if hasattr(block, "input_schema")
else {}
),
"outputSchema": (
block.output_schema.jsonschema()
if hasattr(block, "output_schema")
else {}
),
}
)
return blocks_info
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
"""Extract JSON from LLM response (handles markdown code blocks)."""
if not text:
return None
# Try fenced code block
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# Try raw text
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# Try finding {...} span
start = text.find("{")
end = text.rfind("}")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
# Try finding [...] span
start = text.find("[")
end = text.rfind("]")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
return None

View File

@@ -1,279 +0,0 @@
"""Agent validator - Validates agent structure and connections."""
import logging
import re
from typing import Any
from .utils import get_blocks_info
logger = logging.getLogger(__name__)
class AgentValidator:
"""Validator for AutoGPT agents with detailed error reporting."""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error: str) -> None:
"""Add an error message."""
self.errors.append(error)
def validate_block_existence(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate all block IDs exist in the blocks library."""
valid = True
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
)
valid = False
return valid
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
"""Validate all node IDs referenced in links exist."""
valid = True
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
if not source_id:
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent source_id '{source_id}'."
)
valid = False
if not sink_id:
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
)
valid = False
return valid
def validate_required_inputs(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate required inputs are provided."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
# Get linked inputs
linked_inputs = {
link["sink_name"]
for link in agent.get("links", [])
if link.get("sink_id") == node_id
}
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate linked data types are compatible."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
def get_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_type(source_outputs, link.get("source_name", ""))
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
if source_type and sink_type and not are_compatible(source_type, sink_type):
self.add_error(
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate nested sink links (with _#_ notation)."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(link.get("sink_id"))
if not sink_node:
continue
block = block_map.get(sink_node.get("block_id"))
if not block:
continue
input_props = block.get("inputSchema", {}).get("properties", {})
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
)
valid = False
continue
if not parent_schema.get("additionalProperties"):
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and child in parent_schema.get("properties", {})
):
self.add_error(
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
)
valid = False
return valid
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
"""Validate prompts don't have spaces in template variables."""
valid = True
for node in agent.get("nodes", []):
input_default = node.get("input_default", {})
prompt = input_default.get("prompt", "")
if not isinstance(prompt, str):
continue
# Find {{...}} with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
for match in matches:
content = match.group(1)
if " " in content:
self.add_error(
f"Node '{node.get('id')}' has spaces in template variable: "
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
)
valid = False
return valid
def validate(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Run all validations.
Returns:
Tuple of (is_valid, error_message)
"""
self.errors = []
if blocks_info is None:
blocks_info = get_blocks_info()
checks = [
self.validate_block_existence(agent, blocks_info),
self.validate_link_node_references(agent),
self.validate_required_inputs(agent, blocks_info),
self.validate_data_type_compatibility(agent, blocks_info),
self.validate_nested_sink_links(agent, blocks_info),
self.validate_prompt_spaces(agent),
]
all_passed = all(checks)
if all_passed:
logger.info("Agent validation successful")
return True, None
error_message = "Agent validation failed:\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
return False, error_message
def validate_agent(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Convenience function to validate an agent.
Returns:
Tuple of (is_valid, error_message)
"""
validator = AgentValidator()
return validator.validate(agent, blocks_info)

View File

@@ -8,12 +8,10 @@ from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
apply_all_fixes,
AgentGeneratorNotConfiguredError,
decompose_goal,
generate_agent,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
@@ -27,9 +25,6 @@ from .models import (
logger = logging.getLogger(__name__)
# Maximum retries for agent generation with validation feedback
MAX_GENERATION_RETRIES = 2
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@@ -91,9 +86,8 @@ class CreateAgentTool(BaseTool):
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON from the steps
3. Apply fixes to correct common LLM errors
4. Preview or save based on the save parameter
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
@@ -110,11 +104,13 @@ class CreateAgentTool(BaseTool):
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(description, context)
except ValueError as e:
# Handle missing API key or configuration errors
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
@@ -171,72 +167,32 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Step 2: Generate agent JSON with retry on validation failure
blocks_info = get_blocks_info()
agent_json = None
validation_errors = None
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate agent (include validation errors from previous attempt)
if attempt == 0:
agent_json = await generate_agent(decomposition_result)
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
retry_instructions = {
**decomposition_result,
"previous_errors": validation_errors,
"retry_instructions": (
"The previous generation had validation errors. "
"Please fix these issues in the new generation:\n"
f"{validation_errors}"
),
}
agent_json = await generate_agent(retry_instructions)
if agent_json is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate the agent. Please try again.",
error="Generation failed",
session_id=session_id,
)
continue
# Step 3: Apply fixes to correct common errors
agent_json = apply_all_fixes(agent_json, blocks_info)
# Step 4: Validate the agent
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
if is_valid:
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the workflow."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. Please try again.",
error="Generation failed",
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 4: Preview or save
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(

View File

@@ -8,13 +8,10 @@ from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
apply_agent_patch,
apply_all_fixes,
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
@@ -28,9 +25,6 @@ from .models import (
logger = logging.getLogger(__name__)
# Maximum retries for patch generation with validation feedback
MAX_GENERATION_RETRIES = 2
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@@ -43,7 +37,7 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates a patch to update the agent while preserving unchanged parts."
"Generates updates to the agent while preserving unchanged parts."
)
@property
@@ -98,9 +92,8 @@ class EditAgentTool(BaseTool):
Flow:
1. Fetch the current agent
2. Generate a patch based on the requested changes
3. Apply the patch to create an updated agent
4. Preview or save based on the save parameter
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
@@ -137,121 +130,58 @@ class EditAgentTool(BaseTool):
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate patch with retry on validation failure
blocks_info = get_blocks_info()
updated_agent = None
validation_errors = None
intent = "Applied requested changes"
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate patch (include validation errors from previous attempt)
try:
if attempt == 0:
patch_result = await generate_agent_patch(
update_request, current_agent
)
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
retry_request = (
f"{update_request}\n\n"
f"IMPORTANT: The previous edit had validation errors. "
f"Please fix these issues:\n{validation_errors}"
)
patch_result = await generate_agent_patch(
retry_request, current_agent
)
except ValueError as e:
# Handle missing API key or configuration errors
return ErrorResponse(
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
session_id=session_id,
)
if patch_result is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate changes. Please try rephrasing.",
error="Patch generation failed",
session_id=session_id,
)
continue
# Check if LLM returned clarifying questions
if patch_result.get("type") == "clarifying_questions":
questions = patch_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Step 3: Apply patch and fixes
try:
updated_agent = apply_agent_patch(current_agent, patch_result)
updated_agent = apply_all_fixes(updated_agent, blocks_info)
except Exception as e:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message=f"Failed to apply changes: {str(e)}",
error="patch_apply_failed",
details={"exception": str(e)},
session_id=session_id,
)
validation_errors = str(e)
continue
# Step 4: Validate the updated agent
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
if is_valid:
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
intent = patch_result.get("intent", "Applied requested changes")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(update_request, current_agent)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Updated agent has validation errors after "
f"{MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the changes."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
if result is None:
return ErrorResponse(
message="Failed to generate changes. Please try rephrasing.",
error="Update generation failed",
session_id=session_id,
)
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
assert updated_agent is not None
# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Result is the updated agent JSON
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 5: Preview or save
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. Changes: {intent}. "
f"I've updated the agent. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
@@ -277,10 +207,7 @@ class EditAgentTool(BaseTool):
)
return AgentSavedResponse(
message=(
f"Updated agent '{created_graph.name}' has been saved to your library! "
f"Changes: {intent}"
),
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,

View File

@@ -29,7 +29,7 @@ def mock_embedding_functions():
yield
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent(setup_test_data):
"""Test that the run_agent tool successfully executes an approved agent"""
# Use test data from fixture
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
assert result_data["graph_name"] == "Test Agent"
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_missing_inputs(setup_test_data):
"""Test that the run_agent tool returns error when inputs are missing"""
# Use test data from fixture
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
assert "message" in result_data
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_invalid_agent_id(setup_test_data):
"""Test that the run_agent tool returns error for invalid agent ID"""
# Use test data from fixture
@@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
)
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
"""Test that run_agent works with an agent requiring LLM credentials"""
# Use test data from fixture
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
assert result_data["graph_name"] == "LLM Test Agent"
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
user = setup_test_data["user"]
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
assert "inputs" in result_data["message"].lower()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_with_use_defaults(setup_test_data):
"""Test that run_agent executes successfully with use_defaults=True."""
user = setup_test_data["user"]
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
assert result_data["graph_id"] == graph.id
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
"""Test that run_agent returns setup_requirements when credentials are missing."""
user = setup_firecrawl_test_data["user"]
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_invalid_slug_format(setup_test_data):
"""Test that run_agent returns error for invalid slug format (no slash)."""
user = setup_test_data["user"]
@@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
assert "username/agent-name" in result_data["message"]
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_unauthenticated():
"""Test that run_agent returns need_login for unauthenticated users."""
tool = RunAgentTool()
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
assert "sign in" in result_data["message"].lower()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_schedule_without_cron(setup_test_data):
"""Test that run_agent returns error when scheduling without cron expression."""
user = setup_test_data["user"]
@@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
assert "cron" in result_data["message"].lower()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_schedule_without_name(setup_test_data):
"""Test that run_agent returns error when scheduling without schedule_name."""
user = setup_test_data["user"]

View File

@@ -23,6 +23,7 @@ class PendingHumanReviewModel(BaseModel):
id: Unique identifier for the review record
user_id: ID of the user who must perform the review
node_exec_id: ID of the node execution that created this review
node_id: ID of the node definition (for grouping reviews from same node)
graph_exec_id: ID of the graph execution containing the node
graph_id: ID of the graph template being executed
graph_version: Version number of the graph template
@@ -37,6 +38,10 @@ class PendingHumanReviewModel(BaseModel):
"""
node_exec_id: str = Field(description="Node execution ID (primary key)")
node_id: str = Field(
description="Node definition ID (for grouping)",
default="", # Temporary default for test compatibility
)
user_id: str = Field(description="User ID associated with the review")
graph_exec_id: str = Field(description="Graph execution ID")
graph_id: str = Field(description="Graph ID")
@@ -66,7 +71,9 @@ class PendingHumanReviewModel(BaseModel):
)
@classmethod
def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
def from_db(
cls, review: "PendingHumanReview", node_id: str
) -> "PendingHumanReviewModel":
"""
Convert a database model to a response model.
@@ -74,9 +81,14 @@ class PendingHumanReviewModel(BaseModel):
payload, instructions, and editable flag.
Handles invalid data gracefully by using safe defaults.
Args:
review: Database review object
node_id: Node definition ID (fetched from NodeExecution)
"""
return cls(
node_exec_id=review.nodeExecId,
node_id=node_id,
user_id=review.userId,
graph_exec_id=review.graphExecId,
graph_id=review.graphId,
@@ -107,6 +119,13 @@ class ReviewItem(BaseModel):
reviewed_data: SafeJsonData | None = Field(
None, description="Optional edited data (ignored if approved=False)"
)
auto_approve_future: bool = Field(
default=False,
description=(
"If true and this review is approved, future executions of this same "
"block (node) will be automatically approved. This only affects approved reviews."
),
)
@field_validator("reviewed_data")
@classmethod
@@ -174,6 +193,9 @@ class ReviewRequest(BaseModel):
This request must include ALL pending reviews for a graph execution.
Each review will be either approved (with optional data modifications)
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
Each review item can individually specify whether to auto-approve future executions
of the same block via the `auto_approve_future` field on ReviewItem.
"""
reviews: List[ReviewItem] = Field(

View File

@@ -1,17 +1,27 @@
import asyncio
import logging
from typing import List
from typing import Any, List
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.data.execution import get_graph_execution_meta
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
create_auto_approval_record,
get_pending_reviews_by_node_exec_ids,
get_pending_reviews_for_execution,
get_pending_reviews_for_user,
has_pending_reviews_for_graph_exec,
process_all_reviews_for_execution,
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -127,17 +137,70 @@ async def process_review_action(
detail="At least one review must be provided",
)
# Build review decisions map
# Batch fetch all requested reviews
reviews_map = await get_pending_reviews_by_node_exec_ids(
list(all_request_node_ids), user_id
)
# Validate all reviews were found
missing_ids = all_request_node_ids - set(reviews_map.keys())
if missing_ids:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No pending review found for node execution(s): {', '.join(missing_ids)}",
)
# Validate all reviews belong to the same execution
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
if len(graph_exec_ids) > 1:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="All reviews in a single request must belong to the same execution.",
)
graph_exec_id = next(iter(graph_exec_ids))
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
review_decisions = {}
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
for review in request.reviews:
review_status = (
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
)
# If this review requested auto-approval, don't allow data modifications
reviewed_data = None if review.auto_approve_future else review.reviewed_data
review_decisions[review.node_exec_id] = (
review_status,
review.reviewed_data,
reviewed_data,
review.message,
)
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
# Process all reviews
updated_reviews = await process_all_reviews_for_execution(
@@ -145,6 +208,87 @@ async def process_review_action(
review_decisions=review_decisions,
)
# Create auto-approval records for approved reviews that requested it
# Deduplicate by node_id to avoid race conditions when multiple reviews
# for the same node are processed in parallel
async def create_auto_approval_for_node(
node_id: str, review_result
) -> tuple[str, bool]:
"""
Create auto-approval record for a node.
Returns (node_id, success) tuple for tracking failures.
"""
try:
await create_auto_approval_record(
user_id=user_id,
graph_exec_id=review_result.graph_exec_id,
graph_id=review_result.graph_id,
graph_version=review_result.graph_version,
node_id=node_id,
payload=review_result.payload,
)
return (node_id, True)
except Exception as e:
logger.error(
f"Failed to create auto-approval record for node {node_id}",
exc_info=e,
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
if review_result.status == ReviewStatus.APPROVED
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
*[
create_auto_approval_for_node(node_id, review_result)
for node_id, review_result in nodes_needing_auto_approval.items()
],
return_exceptions=True,
)
# Count auto-approval failures
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
approved_count = sum(
1
@@ -157,30 +301,53 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution if we processed some reviews
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
# Get graph execution ID from any processed review
first_review = next(iter(updated_reviews.values()))
graph_exec_id = first_review.graph_exec_id
# Check if any pending reviews remain for this execution
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Resume execution
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
)
await add_graph_execution(
graph_id=first_review.graph_id,
user_id=user_id,
graph_exec_id=graph_exec_id,
execution_context=execution_context,
)
logger.info(f"Resumed execution {graph_exec_id}")
except Exception as e:
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
# Build error message if auto-approvals failed
error_message = None
if auto_approval_failed_count > 0:
error_message = (
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
f"You may need to manually approve these reviews in future executions."
)
return ReviewResponse(
approved_count=approved_count,
rejected_count=rejected_count,
failed_count=0,
error=None,
failed_count=auto_approval_failed_count,
error=error_message,
)

View File

@@ -583,7 +583,13 @@ async def update_library_agent(
)
update_fields["isDeleted"] = is_deleted
if settings is not None:
update_fields["settings"] = SafeJson(settings.model_dump())
existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
current_settings_dict = (
existing_agent.settings.model_dump() if existing_agent.settings else {}
)
new_settings = settings.model_dump(exclude_unset=True)
merged_settings = {**current_settings_dict, **new_settings}
update_fields["settings"] = SafeJson(merged_settings)
try:
# If graph_version is provided, update to that specific version

View File

@@ -20,6 +20,7 @@ from typing import AsyncGenerator
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
@@ -38,13 +39,13 @@ keysmith = APIKeySmith()
# ============================================================================
@pytest.fixture
@pytest.fixture(scope="session")
def test_user_id() -> str:
"""Test user ID for OAuth tests."""
return str(uuid.uuid4())
@pytest.fixture
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def test_user(server, test_user_id: str):
"""Create a test user in the database."""
await PrismaUser.prisma().create(
@@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str):
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest.fixture
@pytest_asyncio.fixture
async def test_oauth_app(test_user: str):
"""Create a test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
return generate_pkce()
@pytest.fixture
@pytest_asyncio.fixture
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
"""
Create an async HTTP client that talks directly to the FastAPI app.
@@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
assert query_params["error"][0] == "invalid_client"
@pytest.fixture
@pytest_asyncio.fixture
async def inactive_oauth_app(test_user: str):
"""Create an inactive test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked(
assert "revoked" in response.json()["detail"].lower()
@pytest.fixture
@pytest_asyncio.fixture
async def other_oauth_app(test_user: str):
"""Create a second OAuth application for cross-app tests."""
app_id = str(uuid.uuid4())

View File

@@ -1552,7 +1552,7 @@ async def review_store_submission(
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
embedding_success = await ensure_embedding(
await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
@@ -1560,12 +1560,6 @@ async def review_store_submission(
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},

View File

@@ -21,7 +21,6 @@ from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above
@@ -63,49 +62,42 @@ def build_searchable_text(
return " ".join(parts)
async def generate_embedding(text: str) -> list[float] | None:
async def generate_embedding(text: str) -> list[float]:
"""
Generate embedding for text using OpenAI API.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
Raises exceptions on failure - caller should handle.
"""
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
client = get_openai_client()
if not client:
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
return embedding
else:
truncated_text = text
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
async def store_embedding(
@@ -144,48 +136,45 @@ async def store_content_embedding(
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
Raises exceptions on failure - caller should handle.
"""
try:
client = tx if tx else prisma.get_client()
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
async def get_embedding(version_id: str) -> dict[str, Any] | None:
@@ -217,34 +206,31 @@ async def get_content_embedding(
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
Raises exceptions on failure - caller should handle.
"""
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
if result and len(result) > 0:
return result[0]
return None
async def ensure_embedding(
@@ -272,46 +258,38 @@ async def ensure_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
"""
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Build searchable text for embedding
searchable_text = build_searchable_text(name, description, sub_heading, categories)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
async def delete_embedding(version_id: str) -> bool:
@@ -521,6 +499,24 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate unique errors to avoid Sentry spam
if failed > 0:
# Group errors by type and message
error_summary: dict[str, int] = {}
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
error_summary[error_key] = error_summary.get(error_key, 0) + 1
# Log aggregated error summary
error_details = ", ".join(
f"{error} ({count}x)" for error, count in error_summary.items()
)
logger.error(
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
f"Errors: {error_details}"
)
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
@@ -557,11 +553,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
}
async def embed_query(query: str) -> list[float] | None:
async def embed_query(query: str) -> list[float]:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
Raises exceptions on failure - caller should handle.
"""
return await generate_embedding(query)
@@ -594,40 +591,30 @@ async def ensure_content_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
"""
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
@@ -854,9 +841,8 @@ async def semantic_search(
limit = 100
# Generate query embedding
query_embedding = await embed_query(query)
if query_embedding is not None:
try:
query_embedding = await embed_query(query)
# Semantic search with embeddings
embedding_str = embedding_to_vector_string(query_embedding)
@@ -907,24 +893,21 @@ async def semantic_search(
"""
)
try:
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
# Fallback to lexical search if embeddings unavailable
logger.warning("Falling back to lexical search (embeddings unavailable)")
params_lexical: list[Any] = [limit]
user_filter = ""

View File

@@ -298,17 +298,16 @@ async def test_schema_handling_error_cases():
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
# Should raise exception on error
with pytest.raises(Exception, match="Database error"):
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
if __name__ == "__main__":

View File

@@ -80,9 +80,8 @@ async def test_generate_embedding_no_api_key():
) as mock_get_client:
mock_get_client.return_value = None
result = await embeddings.generate_embedding("test text")
assert result is None
with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
await embeddings.generate_embedding("test text")
@pytest.mark.asyncio(loop_scope="session")
@@ -97,9 +96,8 @@ async def test_generate_embedding_api_error():
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is None
with pytest.raises(Exception, match="API Error"):
await embeddings.generate_embedding("test text")
@pytest.mark.asyncio(loop_scope="session")
@@ -173,11 +171,10 @@ async def test_store_embedding_database_error(mocker):
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
with pytest.raises(Exception, match="Database error"):
await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
@pytest.mark.asyncio(loop_scope="session")
@@ -277,17 +274,16 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.return_value = None
mock_generate.side_effect = Exception("Generation failed")
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
with pytest.raises(Exception, match="Generation failed"):
await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -186,13 +186,12 @@ async def unified_hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation if embedding unavailable
if query_embedding is None or not query_embedding:
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search. "
f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
query_embedding = [0.0] * EMBEDDING_DIM
@@ -464,13 +463,12 @@ async def hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation
if query_embedding is None or not query_embedding:
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search."
f"Failed to generate query embedding - falling back to lexical-only search: {e}"
)
query_embedding = [0.0] * EMBEDDING_DIM
total_non_semantic = (

View File

@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate embedding failure
mock_embed.return_value = None
# Simulate embedding failure by raising exception
mock_embed.side_effect = Exception("Embedding generation failed")
mock_query.return_value = mock_results
# Should NOT raise - graceful degradation
@@ -613,7 +613,9 @@ async def test_unified_hybrid_search_graceful_degradation():
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = None # Embedding failure
mock_embed.side_effect = Exception(
"Embedding generation failed"
) # Embedding failure
# Should NOT raise - graceful degradation
results, total = await unified_hybrid_search(

View File

@@ -393,7 +393,6 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)

View File

@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -38,11 +37,9 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -112,27 +109,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# migrate_llm_models uses registry default model
from backend.blocks.llm import LlmModel
default_model_slug = llm_registry.get_default_model_slug()
if default_model_slug:
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
else:
logger.warning("Skipping LLM model migration: no default model available")
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -317,16 +298,6 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
app.include_router(
backend.api.features.admin.llm_routes.router,
tags=["v2", "admin", "llm"],
prefix="/api/llm/admin",
)
app.include_router(
public_llm_routes.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)

View File

@@ -77,39 +77,7 @@ async def event_broadcaster(manager: ConnectionManager):
payload=notification.payload,
)
async def registry_refresh_worker():
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in pubsub.listen():
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={
"type": "LLM_REGISTRY_REFRESH",
"event": "registry_updated",
},
)
await asyncio.gather(
execution_worker(),
notification_worker(),
registry_refresh_worker(),
)
await asyncio.gather(execution_worker(), notification_worker())
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -1,6 +1,7 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -9,7 +10,6 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
llm_model_schema_extra,
)
from backend.data.block import (
BlockCategory,
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for evaluating the condition.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -116,6 +116,7 @@ class PrintToConsoleBlock(Block):
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -0,0 +1,659 @@
import json
import shlex
import uuid
from typing import Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class ClaudeCodeExecutionError(Exception):
"""Exception raised when Claude Code execution fails.
Carries the sandbox_id so it can be returned to the user for cleanup
when dispose_sandbox=False.
"""
def __init__(self, message: str, sandbox_id: str = ""):
super().__init__(message)
self.sandbox_id = sandbox_id
# Test credentials for E2B
TEST_E2B_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_E2B_CREDENTIALS_INPUT = {
"provider": TEST_E2B_CREDENTIALS.provider,
"id": TEST_E2B_CREDENTIALS.id,
"type": TEST_E2B_CREDENTIALS.type,
"title": TEST_E2B_CREDENTIALS.title,
}
# Test credentials for Anthropic
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
provider="anthropic",
api_key=SecretStr("mock-anthropic-api-key"),
title="Mock Anthropic API key",
expires_at=None,
)
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
"id": TEST_ANTHROPIC_CREDENTIALS.id,
"type": TEST_ANTHROPIC_CREDENTIALS.type,
"title": TEST_ANTHROPIC_CREDENTIALS.title,
}
class ClaudeCodeBlock(Block):
"""
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
Claude Code can create files, install tools, run commands, and perform complex
coding tasks autonomously within a secure sandbox environment.
"""
# Use base template - we'll install Claude Code ourselves for latest version
DEFAULT_TEMPLATE = "base"
class Input(BlockSchemaInput):
e2b_credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"API key for the E2B platform to create the sandbox. "
"Get one on the [e2b website](https://e2b.dev/docs)"
),
)
anthropic_credentials: CredentialsMetaInput[
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
] = CredentialsField(
description=(
"API key for Anthropic to power Claude Code. "
"Get one at [Anthropic's website](https://console.anthropic.com)"
),
)
prompt: str = SchemaField(
description=(
"The task or instruction for Claude Code to execute. "
"Claude Code can create files, install packages, run commands, "
"and perform complex coding tasks."
),
placeholder="Create a hello world index.html file",
default="",
advanced=False,
)
timeout: int = SchemaField(
description=(
"Sandbox timeout in seconds. Claude Code tasks can take "
"a while, so set this appropriately for your task complexity. "
"Note: This only applies when creating a new sandbox. "
"When reconnecting to an existing sandbox via sandbox_id, "
"the original timeout is retained."
),
default=300, # 5 minutes default
advanced=True,
)
setup_commands: list[str] = SchemaField(
description=(
"Optional shell commands to run before executing Claude Code. "
"Useful for installing dependencies or setting up the environment."
),
default_factory=list,
advanced=True,
)
working_directory: str = SchemaField(
description="Working directory for Claude Code to operate in.",
default="/home/user",
advanced=True,
)
# Session/continuation support
session_id: str = SchemaField(
description=(
"Session ID to resume a previous conversation. "
"Leave empty for a new conversation. "
"Use the session_id from a previous run to continue that conversation."
),
default="",
advanced=True,
)
sandbox_id: str = SchemaField(
description=(
"Sandbox ID to reconnect to an existing sandbox. "
"Required when resuming a session (along with session_id). "
"Use the sandbox_id from a previous run where dispose_sandbox was False."
),
default="",
advanced=True,
)
conversation_history: str = SchemaField(
description=(
"Previous conversation history to continue from. "
"Use this to restore context on a fresh sandbox if the previous one timed out. "
"Pass the conversation_history output from a previous run."
),
default="",
advanced=True,
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"Set to False if you want to continue the conversation later "
"(you'll need both sandbox_id and session_id from the output)."
),
default=True,
advanced=True,
)
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput):
response: str = SchemaField(
description="The output/response from Claude Code execution"
)
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=(
"List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
)
)
conversation_history: str = SchemaField(
description=(
"Full conversation history including this turn. "
"Pass this to conversation_history input to continue on a fresh sandbox "
"if the previous sandbox timed out."
)
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back along with sandbox_id to continue the conversation."
)
)
sandbox_id: Optional[str] = SchemaField(
description=(
"ID of the sandbox instance. "
"Pass this back along with session_id to continue the conversation. "
"This is None if dispose_sandbox was True (sandbox was disposed)."
),
default=None,
)
error: str = SchemaField(description="Error message if execution failed")
def __init__(self):
super().__init__(
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
description=(
"Execute tasks using Claude Code in an E2B sandbox. "
"Claude Code can create files, install tools, run commands, "
"and perform complex coding tasks autonomously."
),
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
input_schema=ClaudeCodeBlock.Input,
output_schema=ClaudeCodeBlock.Output,
test_credentials={
"e2b_credentials": TEST_E2B_CREDENTIALS,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
},
test_input={
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
"prompt": "Create a hello world HTML file",
"timeout": 300,
"setup_commands": [],
"working_directory": "/home/user",
"session_id": "",
"sandbox_id": "",
"conversation_history": "",
"dispose_sandbox": True,
},
test_output=[
("response", "Created index.html with hello world content"),
(
"files",
[
{
"path": "/home/user/index.html",
"relative_path": "index.html",
"name": "index.html",
"content": "<html>Hello World</html>",
}
],
),
(
"conversation_history",
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content",
),
("session_id", str),
("sandbox_id", None), # None because dispose_sandbox=True in test_input
],
test_mock={
"execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response
[
ClaudeCodeBlock.FileOutput(
path="/home/user/index.html",
relative_path="index.html",
name="index.html",
content="<html>Hello World</html>",
)
], # files
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content", # conversation_history
"test-session-id", # session_id
"sandbox_id", # sandbox_id
),
},
)
async def execute_claude_code(
self,
e2b_api_key: str,
anthropic_api_key: str,
prompt: str,
timeout: int,
setup_commands: list[str],
working_directory: str,
session_id: str,
existing_sandbox_id: str,
conversation_history: str,
dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
"""
Execute Claude Code in an E2B sandbox.
Returns:
Tuple of (response, files, conversation_history, session_id, sandbox_id)
"""
# Validate that sandbox_id is provided when resuming a session
if session_id and not existing_sandbox_id:
raise ValueError(
"sandbox_id is required when resuming a session with session_id. "
"The session state is stored in the original sandbox. "
"If the sandbox has timed out, use conversation_history instead "
"to restore context on a fresh sandbox."
)
sandbox = None
sandbox_id = ""
try:
# Either reconnect to existing sandbox or create a new one
if existing_sandbox_id:
# Reconnect to existing sandbox for conversation continuation
sandbox = await BaseAsyncSandbox.connect(
sandbox_id=existing_sandbox_id,
api_key=e2b_api_key,
)
else:
# Create new sandbox
sandbox = await BaseAsyncSandbox.create(
template=self.DEFAULT_TEMPLATE,
api_key=e2b_api_key,
timeout=timeout,
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
)
# Install Claude Code from npm (ensures we get the latest version)
install_result = await sandbox.commands.run(
"npm install -g @anthropic-ai/claude-code@latest",
timeout=120, # 2 min timeout for install
)
if install_result.exit_code != 0:
raise Exception(
f"Failed to install Claude Code: {install_result.stderr}"
)
# Run any user-provided setup commands
for cmd in setup_commands:
setup_result = await sandbox.commands.run(cmd)
if setup_result.exit_code != 0:
raise Exception(
f"Setup command failed: {cmd}\n"
f"Exit code: {setup_result.exit_code}\n"
f"Stdout: {setup_result.stdout}\n"
f"Stderr: {setup_result.stderr}"
)
# Capture sandbox_id immediately after creation/connection
# so it's available for error recovery if dispose_sandbox=False
sandbox_id = sandbox.sandbox_id
# Generate or use provided session ID
current_session_id = session_id if session_id else str(uuid.uuid4())
# Build base Claude flags
base_flags = "-p --dangerously-skip-permissions --output-format json"
# Add conversation history context if provided (for fresh sandbox continuation)
history_flag = ""
if conversation_history and not session_id:
# Inject previous conversation as context via system prompt
# Use consistent escaping via _escape_prompt helper
escaped_history = self._escape_prompt(
f"Previous conversation context: {conversation_history}"
)
history_flag = f" --append-system-prompt {escaped_history}"
# Build Claude command based on whether we're resuming or starting new
# Use shlex.quote for working_directory and session IDs to prevent injection
safe_working_dir = shlex.quote(working_directory)
if session_id:
# Resuming existing session (sandbox still alive)
safe_session_id = shlex.quote(session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --resume {safe_session_id} {base_flags}"
)
else:
# New session with specific ID
safe_current_session_id = shlex.quote(current_session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
)
# Capture timestamp before running Claude Code to filter files later
# Capture timestamp 1 second in the past to avoid race condition with file creation
timestamp_result = await sandbox.commands.run(
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
)
if timestamp_result.exit_code != 0:
raise RuntimeError(
f"Failed to capture timestamp: {timestamp_result.stderr}"
)
start_timestamp = (
timestamp_result.stdout.strip() if timestamp_result.stdout else None
)
result = await sandbox.commands.run(
claude_command,
timeout=0, # No command timeout - let sandbox timeout handle it
)
# Check for command failure
if result.exit_code != 0:
error_msg = result.stderr or result.stdout or "Unknown error"
raise Exception(
f"Claude Code command failed with exit code {result.exit_code}:\n"
f"{error_msg}"
)
raw_output = result.stdout or ""
# Parse JSON output to extract response and build conversation history
response = ""
new_conversation_history = conversation_history or ""
try:
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
except json.JSONDecodeError:
# If not valid JSON, use raw output
response = raw_output
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
# Extract files created/modified during this run
files = await self._extract_files(
sandbox, working_directory, start_timestamp
)
return (
response,
files,
new_conversation_history,
current_session_id,
sandbox_id,
)
except Exception as e:
# Wrap exception with sandbox_id so caller can access/cleanup
# the preserved sandbox when dispose_sandbox=False
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
finally:
if dispose_sandbox and sandbox:
await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
async def run(
self,
input_data: Input,
*,
e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
try:
(
response,
files,
conversation_history,
session_id,
sandbox_id,
) = await self.execute_claude_code(
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
prompt=input_data.prompt,
timeout=input_data.timeout,
setup_commands=input_data.setup_commands,
working_directory=input_data.working_directory,
session_id=input_data.session_id,
existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox,
)
yield "response", response
# Always yield files (empty list if none) to match Output schema
yield "files", [f.model_dump() for f in files]
# Always yield conversation_history so user can restore context on fresh sandbox
yield "conversation_history", conversation_history
# Always yield session_id so user can continue conversation
yield "session_id", session_id
# Always yield sandbox_id (None if disposed) to match Output schema
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
except ClaudeCodeExecutionError as e:
yield "error", str(e)
# If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
# so user can reconnect to or clean up the orphaned sandbox
if not input_data.dispose_sandbox and e.sandbox_id:
yield "sandbox_id", e.sandbox_id
except Exception as e:
yield "error", str(e)

View File

@@ -9,7 +9,7 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.execution import ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def check_approval(**kwargs) -> Optional[ReviewResult]:
"""Check if there's an existing approval for this node execution."""
return await get_database_manager_async_client().check_approval(**kwargs)
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
@@ -55,11 +60,11 @@ class HITLReviewHelper:
async def _handle_review_request(
input_data: Any,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
@@ -69,11 +74,11 @@ class HITLReviewHelper:
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_id: ID of the node in the graph definition
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
@@ -83,15 +88,41 @@ class HITLReviewHelper:
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.human_in_the_loop_safe_mode:
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
# are handled by the caller:
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
# This function only handles checking for existing approvals.
# Check if this node has already been approved (normal or auto-approval)
if approval_result := await HITLReviewHelper.check_approval(
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
node_id=node_id,
user_id=user_id,
input_data=input_data,
):
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
f"Block {block_name} skipping review for node {node_exec_id} - "
f"found existing approval"
)
# Return a new ReviewResult with the current node_exec_id but approved status
# For auto-approvals, always use current input_data
# For normal approvals, use approval_result.data unless it's None
is_auto_approval = approval_result.node_exec_id != node_exec_id
approved_data = (
input_data
if is_auto_approval
else (
approval_result.data
if approval_result.data is not None
else input_data
)
)
return ReviewResult(
data=input_data,
data=approved_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
message=approval_result.message,
processed=True,
node_exec_id=node_exec_id,
)
@@ -103,7 +134,7 @@ class HITLReviewHelper:
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
message=block_name, # Use block_name directly as the message
editable=editable,
)
@@ -129,11 +160,11 @@ class HITLReviewHelper:
async def handle_review_decision(
input_data: Any,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
@@ -143,11 +174,11 @@ class HITLReviewHelper:
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_id: ID of the node in the graph definition
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
@@ -158,11 +189,11 @@ class HITLReviewHelper:
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)

View File

@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
input_data: Input,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
@@ -115,12 +116,12 @@ class HumanInTheLoopBlock(Block):
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
block_name=input_data.name, # Use user-provided name instead of block type
editable=input_data.editable,
)

View File

@@ -4,19 +4,17 @@ import logging
import re
import secrets
from abc import ABC
from enum import Enum
from enum import Enum, EnumMeta
from json import JSONDecodeError
from typing import Any, Iterable, List, Literal, Optional
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
from pydantic_core import CoreSchema, core_schema
from pydantic import BaseModel, SecretStr
from backend.data import llm_registry
from backend.data.block import (
Block,
BlockCategory,
@@ -24,7 +22,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -69,123 +66,114 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
"""
Returns a CredentialsField for LLM providers.
The discriminator_mapping will be refreshed when the schema is generated
if it's empty, ensuring the LLM registry is loaded.
"""
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
discriminator_mapping=mapping, # May be empty initially, refreshed later
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)
def llm_model_schema_extra() -> dict[str, Any]:
return {"options": llm_registry.get_llm_model_schema_options()}
class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
class LlmModelMeta(type):
"""
Metaclass for LlmModel that enables attribute-style access to dynamic models.
This allows code like `LlmModel.GPT4O` to work by converting the attribute
name to a slug format:
- GPT4O -> gpt-4o
- GPT4O_MINI -> gpt-4o-mini
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
"""
def __getattr__(cls, name: str):
# Don't intercept private/dunder attributes
if name.startswith("_"):
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
# Convert attribute name to slug format:
# 1. Lowercase: GPT4O -> gpt4o
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
slug = name.lower().replace("_", "-")
# Check for exact match in registry first (e.g., "o1" stays "o1")
registry_slugs = llm_registry.get_dynamic_model_slugs()
if slug in registry_slugs:
return cls(slug)
# If no exact match, try inserting hyphen between letter and digit
# e.g., gpt4o -> gpt-4o
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
return cls(transformed_slug)
def __iter__(cls):
"""Iterate over all models from the registry.
Yields LlmModel instances for each model in the dynamic registry.
Used by __get_pydantic_json_schema__ to build model metadata.
"""
for model in llm_registry.iter_dynamic_models():
yield cls(model.slug)
class LlmModelMeta(EnumMeta):
pass
class LlmModel(str, metaclass=LlmModelMeta):
"""
Dynamic LLM model type that accepts any model slug from the registry.
This is a string subclass (not an Enum) that allows any model slug value.
All models are managed via the LLM Registry in the database.
Usage:
model = LlmModel("gpt-4o") # Direct construction
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
model.value # Returns the slug string
model.provider # Returns the provider from registry
"""
def __new__(cls, value: str):
if isinstance(value, LlmModel):
return value
return str.__new__(cls, value)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""
Tell Pydantic how to validate LlmModel.
Accepts strings and converts them to LlmModel instances.
"""
return core_schema.no_info_after_validator_function(
cls, # The validator function (LlmModel constructor)
core_schema.str_schema(), # Accept string input
serialization=core_schema.to_string_ser_schema(), # Serialize as string
)
@property
def value(self) -> str:
"""Return the model slug (for compatibility with enum-style access)."""
return str(self)
@classmethod
def default(cls) -> "LlmModel":
"""
Get the default model from the registry.
Returns the recommended model if set, otherwise gpt-4o if available
and enabled, otherwise the first enabled model from the registry.
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
"""
from backend.data.llm_registry import get_default_model_slug
slug = get_default_model_slug()
if slug is None:
# Registry is empty (e.g., at module import time before DB connection).
# Fall back to gpt-4o for backward compatibility.
slug = "gpt-4o"
return cls(slug)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
# Groq models
LLAMA3_3_70B = "llama-3.3-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_3 = "llama3.3"
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
@@ -193,15 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
llm_model_metadata = {}
for model in cls:
model_name = model.value
# Skip disabled models - only show enabled models in the picker
if not llm_registry.is_model_enabled(model_name):
continue
# Use registry directly with None check to gracefully handle
# missing metadata during startup/import before registry is populated
metadata = llm_registry.get_llm_model_metadata(model_name)
if metadata is None:
# Skip models without metadata (registry not yet populated)
continue
metadata = model.metadata
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
@@ -217,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
return MODEL_METADATA[self]
@property
def provider(self) -> str:
@@ -237,11 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens
# MODEL_METADATA removed - all models now come from the database via llm_registry
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
LlmModel.O3_MINI: ModelMetadata(
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata(
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata(
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata(
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
),
LlmModel.GPT5_1: ModelMetadata(
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
),
LlmModel.GPT5: ModelMetadata(
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_MINI: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_NANO: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_CHAT: ModelMetadata(
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
),
LlmModel.GPT41: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
),
LlmModel.GPT41_MINI: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata(
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
"aiml_api",
128000,
40000,
"Llama 3.1 Nemotron 70B Instruct",
"AI/ML",
"Nvidia",
1,
),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
),
# https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata(
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
),
LlmModel.LLAMA3_1_8B: ModelMetadata(
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65535,
"Gemini 2.5 Flash Lite Preview 06.17",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
"open_router",
1048576,
8192,
"Gemini 2.0 Flash Lite 001",
"OpenRouter",
"Google",
1,
),
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
16000,
"Sonar Deep Research",
"OpenRouter",
"Perplexity",
3,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router",
131000,
4096,
"Hermes 3 Llama 3.1 405B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router",
12288,
12288,
"Hermes 3 Llama 3.1 70B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Qwen 3 235B A22B Thinking 2507",
"OpenRouter",
"Qwen",
1,
),
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Scout 17B 16E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Maverick 17B 128E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
}
# Default model constant for backward compatibility
# Uses the dynamic registry to get the default model
DEFAULT_LLM_MODEL = LlmModel.default()
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
@@ -334,10 +598,7 @@ def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
# parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
is_o_series = re.match(r"^o\d", llm_model) is not None
if is_o_series or parallel_tool_calls is None:
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.NOT_GIVEN
return parallel_tool_calls
@@ -373,98 +634,19 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
# Get model metadata and check if enabled - with fallback support
# The model we'll actually use (may differ if original is disabled)
model_to_use = llm_model.value
# Check if model is in registry and if it's enabled
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
model_info = get_model_info(llm_model.value)
if model_info and not model_info.is_enabled:
# Model is disabled - try to find a fallback from the same provider
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
)
model_to_use = fallback.slug
# Use fallback model's metadata
provider = fallback.metadata.provider
context_window = fallback.metadata.context_window
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
# No fallback available - raise error
raise ValueError(
f"LLM model '{llm_model.value}' is disabled and no fallback model "
f"from the same provider is available. Please enable the model or "
f"select a different model in the block configuration."
)
else:
# Model is enabled or not in registry (legacy/static model)
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
except ValueError:
# Model not in cache - try refreshing the registry once if we have DB access
logger.warning(f"Model {llm_model.value} not found in registry cache")
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
logger.info(
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
)
except ValueError:
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
)
# Create effective model for model-specific parameter resolution (e.g., o-series check)
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
effective_model = LlmModel(model_to_use)
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=context_window // 2,
target_tokens=llm_model.context_window // 2,
lossy_ok=True,
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
# model_max_output already set above
model_max_output = llm_model.max_output_tokens or int(2**15)
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -475,14 +657,14 @@ async def llm_call(
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
if force_json_output:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -529,7 +711,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=model_to_use,
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -593,7 +775,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -615,7 +797,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=model_to_use,
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -637,7 +819,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -645,7 +827,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -679,7 +861,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -687,7 +869,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -714,7 +896,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.AsyncOpenAI(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
@@ -724,8 +906,8 @@ async def llm_call(
},
)
completion = await client.chat.completions.create(
model=model_to_use,
completion = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -753,11 +935,11 @@ async def llm_call(
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -808,10 +990,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
@@ -874,7 +1055,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1240,10 +1421,9 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
@@ -1337,9 +1517,8 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for summarizing the text.",
json_schema_extra=llm_model_schema_extra(),
)
focus: str = SchemaField(
title="Focus",
@@ -1555,9 +1734,8 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for the conversation.",
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_tokens: int | None = SchemaField(
@@ -1594,7 +1772,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1657,10 +1835,9 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for generating the list.",
advanced=True,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_retries: int = SchemaField(
@@ -1715,7 +1892,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

View File

@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default_factory=llm.LlmModel.default,
default=llm.DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm.llm_model_schema_extra(),
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(

View File

@@ -10,13 +10,13 @@ import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = self.metadata
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
@property
def provider(self) -> str:
return self.metadata.provider
return MODEL_METADATA[LlmModel(self.value)].provider
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
# Fallback to LlmModel enum if registry lookup fails
return LlmModel(self.value).metadata
return MODEL_METADATA[LlmModel(self.value)]
@property
def context_window(self) -> int:
return self.metadata.context_window
return MODEL_METADATA[LlmModel(self.value)].context_window
@property
def max_output_tokens(self) -> int | None:
return self.metadata.max_output_tokens
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
class StagehandObserveBlock(Block):

View File

@@ -1,7 +1,7 @@
import logging
import os
import pytest
import pytest_asyncio
from dotenv import load_dotenv
from backend.util.logging import configure_logging
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
prisma_logger.setLevel(logging.INFO)
@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def server():
from backend.util.test import SpinTestServer
@@ -27,7 +27,7 @@ async def server():
yield server
@pytest.fixture(scope="session", autouse=True)
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def graph_cleanup(server):
created_graph_ids = []
original_create_graph = server.agent_server.test_create_graph

View File

@@ -25,7 +25,6 @@ from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
@@ -144,59 +143,35 @@ class BlockInfo(BaseModel):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
@classmethod
def clear_schema_cache(cls) -> None:
"""Clear the cached JSON schema for this class."""
# Use None instead of {} because {} is truthy and would prevent regeneration
cls.cached_jsonschema = None # type: ignore
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
# Generate schema if not cached
if not cls.cached_jsonschema:
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
if cls.cached_jsonschema:
return cls.cached_jsonschema
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next(
(k for k in keys if k in obj and len(obj[k]) == 1), None
)
if one_key:
obj.update(obj[one_key][0])
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return obj
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return obj
# Always post-process to ensure LLM registry data is up-to-date
# This refreshes model options and discriminator mappings even if schema was cached
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@@ -259,7 +234,7 @@ class BlockSchema(BaseModel):
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = None
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
@@ -466,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
@@ -498,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.is_sensitive_action: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -647,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
@@ -673,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
@@ -896,28 +873,6 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# Refresh LLM registry before initializing blocks so blocks can use registry data
# This ensures the registry cache is populated even in executor context
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs

View File

@@ -1,4 +1,3 @@
import logging
from typing import Type
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
@@ -24,18 +23,19 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIListGeneratorBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data import llm_registry
from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
@@ -55,63 +55,210 @@ from backend.integrations.credentials_store import (
v0_credentials,
)
logger = logging.getLogger(__name__)
# =============== Configure the cost for each LLM Model call =============== #
PROVIDER_CREDENTIALS = {
"openai": openai_credentials,
"anthropic": anthropic_credentials,
"groq": groq_credentials,
"open_router": open_router_credentials,
"llama_api": llama_api_credentials,
"aiml_api": aiml_api_credentials,
"v0": v0_credentials,
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
}
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry
LLM_COST: list[BlockCost] = []
for model in LlmModel:
if model not in MODEL_COST:
raise ValueError(f"Missing MODEL_COST for model: {model}")
def _build_llm_costs_from_registry() -> list[BlockCost]:
"""Build BlockCost list from all models in the LLM registry."""
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
if not credentials:
logger.warning(
"Skipping cost entry for %s due to unknown credentials provider %s",
model.slug,
cost.credential_provider,
)
continue
cost_filter = {
"model": model.slug,
LLM_COST = (
# Anthropic Models
[
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": credentials.id,
"provider": credentials.provider,
"type": credentials.type,
"id": anthropic_credentials.id,
"provider": anthropic_credentials.provider,
"type": anthropic_credentials.type,
},
}
costs.append(
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost.credit_cost,
)
)
return costs
def refresh_llm_costs() -> None:
"""Refresh LLM costs from the registry. All costs now come from the database."""
LLM_COST.clear()
LLM_COST.extend(_build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "anthropic"
]
# OpenAI Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
"type": openai_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "openai"
]
# Groq Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {"id": groq_credentials.id},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "groq"
]
# Open Router Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": open_router_credentials.id,
"provider": open_router_credentials.provider,
"type": open_router_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "open_router"
]
# Llama API Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": llama_api_credentials.id,
"provider": llama_api_credentials.provider,
"type": llama_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": aiml_api_credentials.id,
"provider": aiml_api_credentials.provider,
"type": aiml_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "aiml_api"
]
)
# =============== This is the exhaustive list of cost for each Block =============== #

View File

@@ -1511,10 +1511,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Get all model slugs from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block

View File

@@ -6,10 +6,10 @@ Handles all database operations for pending human reviews.
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional
from typing import TYPE_CHECKING, Optional
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.models import AgentNodeExecution, PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from pydantic import BaseModel
@@ -17,8 +17,12 @@ from backend.api.features.executions.review.model import (
PendingHumanReviewModel,
SafeJsonData,
)
from backend.data.execution import get_graph_execution_meta
from backend.util.json import SafeJson
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@@ -32,6 +36,125 @@ class ReviewResult(BaseModel):
node_exec_id: str
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
"""Generate the special nodeExecId key for auto-approval records."""
return f"auto_approve_{graph_exec_id}_{node_id}"
async def check_approval(
node_exec_id: str,
graph_exec_id: str,
node_id: str,
user_id: str,
input_data: SafeJsonData | None = None,
) -> Optional[ReviewResult]:
"""
Check if there's an existing approval for this node execution.
Checks both:
1. Normal approval by node_exec_id (previous run of the same node execution)
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
Args:
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
node_id: ID of the node definition (not execution)
user_id: ID of the user (for data isolation)
input_data: Current input data (used for auto-approvals to avoid stale data)
Returns:
ReviewResult if approval found (either normal or auto), None otherwise
"""
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
# Check for either normal approval or auto-approval in a single query
existing_review = await PendingHumanReview.prisma().find_first(
where={
"OR": [
{"nodeExecId": node_exec_id},
{"nodeExecId": auto_approve_key},
],
"status": ReviewStatus.APPROVED,
"userId": user_id,
},
)
if existing_review:
is_auto_approval = existing_review.nodeExecId == auto_approve_key
logger.info(
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
)
# For auto-approvals, use current input_data to avoid replaying stale payload
# For normal approvals, use the stored payload (which may have been edited)
return ReviewResult(
data=(
input_data
if is_auto_approval and input_data is not None
else existing_review.payload
),
status=ReviewStatus.APPROVED,
message=(
"Auto-approved (user approved all future actions for this node)"
if is_auto_approval
else existing_review.reviewMessage or ""
),
processed=True,
node_exec_id=existing_review.nodeExecId,
)
return None
async def create_auto_approval_record(
user_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
node_id: str,
payload: SafeJsonData,
) -> None:
"""
Create an auto-approval record for a node in this execution.
This is stored as a PendingHumanReview with a special nodeExecId pattern
and status=APPROVED, so future executions of the same node can skip review.
Raises:
ValueError: If the graph execution doesn't belong to the user
"""
# Validate that the graph execution belongs to this user (defense in depth)
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
)
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
await PendingHumanReview.prisma().upsert(
where={"nodeExecId": auto_approve_key},
data={
"create": {
"nodeExecId": auto_approve_key,
"userId": user_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(payload),
"instructions": "Auto-approval record",
"editable": False,
"status": ReviewStatus.APPROVED,
"processed": True,
"reviewedAt": datetime.now(timezone.utc),
},
"update": {}, # Already exists, no update needed
},
)
async def get_or_create_human_review(
user_id: str,
node_exec_id: str,
@@ -108,6 +231,87 @@ async def get_or_create_human_review(
)
async def get_pending_review_by_node_exec_id(
node_exec_id: str, user_id: str
) -> Optional["PendingHumanReviewModel"]:
"""
Get a pending review by its node execution ID.
Args:
node_exec_id: The node execution ID to look up
user_id: User ID for authorization (only returns if review belongs to this user)
Returns:
The pending review if found and belongs to user, None otherwise
"""
review = await PendingHumanReview.prisma().find_first(
where={
"nodeExecId": node_exec_id,
"userId": user_id,
"status": ReviewStatus.WAITING,
}
)
if not review:
return None
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
return PendingHumanReviewModel.from_db(review, node_id=node_id)
async def get_pending_reviews_by_node_exec_ids(
node_exec_ids: list[str], user_id: str
) -> dict[str, "PendingHumanReviewModel"]:
"""
Get multiple pending reviews by their node execution IDs in a single batch query.
Args:
node_exec_ids: List of node execution IDs to look up
user_id: User ID for authorization (only returns reviews belonging to this user)
Returns:
Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews
"""
if not node_exec_ids:
return {}
reviews = await PendingHumanReview.prisma().find_many(
where={
"nodeExecId": {"in": node_exec_ids},
"userId": user_id,
"status": ReviewStatus.WAITING,
}
)
if not reviews:
return {}
# Batch fetch all node executions to avoid N+1 queries
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
node_execs = await AgentNodeExecution.prisma().find_many(
where={"id": {"in": node_exec_ids_to_fetch}},
include={"Node": True},
)
# Create mapping from node_exec_id to node_id
node_exec_id_to_node_id = {
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
}
result = {}
for review in reviews:
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
return result
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
"""
Check if a graph execution has any pending reviews.
@@ -137,8 +341,11 @@ async def get_pending_reviews_for_user(
page_size: Number of reviews per page
Returns:
List of pending review models
List of pending review models with node_id included
"""
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
# Calculate offset for pagination
offset = (page - 1) * page_size
@@ -149,7 +356,14 @@ async def get_pending_reviews_for_user(
take=page_size,
)
return [PendingHumanReviewModel.from_db(review) for review in reviews]
# Fetch node_id for each review from NodeExecution
result = []
for review in reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
async def get_pending_reviews_for_execution(
@@ -163,8 +377,11 @@ async def get_pending_reviews_for_execution(
user_id: User ID for security validation
Returns:
List of pending review models
List of pending review models with node_id included
"""
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
reviews = await PendingHumanReview.prisma().find_many(
where={
"userId": user_id,
@@ -174,7 +391,14 @@ async def get_pending_reviews_for_execution(
order={"createdAt": "asc"},
)
return [PendingHumanReviewModel.from_db(review) for review in reviews]
# Fetch node_id for each review from NodeExecution
result = []
for review in reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
return result
async def process_all_reviews_for_execution(
@@ -244,11 +468,19 @@ async def process_all_reviews_for_execution(
# Note: Execution resumption is now handled at the API layer after ALL reviews
# for an execution are processed (both approved and rejected)
# Return as dict for easy access
return {
review.nodeExecId: PendingHumanReviewModel.from_db(review)
for review in updated_reviews
}
# Fetch node_id for each review and return as dict for easy access
# Local import to avoid event loop conflicts in tests
from backend.data.execution import get_node_execution
result = {}
for review in updated_reviews:
node_exec = await get_node_execution(review.nodeExecId)
node_id = node_exec.node_id if node_exec else review.nodeExecId
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
review, node_id=node_id
)
return result
async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
@@ -256,3 +488,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
await PendingHumanReview.prisma().update(
where={"nodeExecId": node_exec_id}, data={"processed": processed}
)
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
"""
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
Args:
graph_exec_id: The graph execution ID
user_id: User ID who owns the execution (for security validation)
Returns:
Number of reviews cancelled
Raises:
ValueError: If the graph execution doesn't belong to the user
"""
# Validate user ownership before cancelling reviews
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
)
result = await PendingHumanReview.prisma().update_many(
where={
"graphExecId": graph_exec_id,
"userId": user_id,
"status": ReviewStatus.WAITING,
},
data={
"status": ReviewStatus.REJECTED,
"reviewMessage": "Execution was stopped by user",
"processed": True,
"reviewedAt": datetime.now(timezone.utc),
},
)
return result

View File

@@ -36,7 +36,7 @@ def sample_db_review():
return mock_review
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_get_or_create_human_review_new(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
sample_db_review.status = ReviewStatus.WAITING
sample_db_review.processed = False
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
result = await get_or_create_human_review(
user_id="test-user-123",
@@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new(
assert result is None
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_get_or_create_human_review_approved(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
sample_db_review.processed = False
sample_db_review.reviewMessage = "Looks good"
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
result = await get_or_create_human_review(
user_id="test-user-123",
@@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved(
assert result.message == "Looks good"
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_has_pending_reviews_for_graph_exec_true(
mocker: pytest_mock.MockFixture,
):
@@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true(
assert result is True
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_has_pending_reviews_for_graph_exec_false(
mocker: pytest_mock.MockFixture,
):
@@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false(
assert result is False
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_get_pending_reviews_for_user(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -131,10 +131,19 @@ async def test_get_pending_reviews_for_user(
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)
assert len(result) == 1
assert result[0].node_exec_id == "test_node_123"
assert result[0].node_id == "test_node_def_789"
# Verify pagination parameters
call_args = mock_find_many.return_value.find_many.call_args
@@ -142,7 +151,7 @@ async def test_get_pending_reviews_for_user(
assert call_args.kwargs["take"] == 10
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_get_pending_reviews_for_execution(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -151,12 +160,21 @@ async def test_get_pending_reviews_for_execution(
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await get_pending_reviews_for_execution(
"test_graph_exec_456", "test-user-123"
)
assert len(result) == 1
assert result[0].graph_exec_id == "test_graph_exec_456"
assert result[0].node_id == "test_node_def_789"
# Verify it filters by execution and user
call_args = mock_find_many.return_value.find_many.call_args
@@ -166,7 +184,7 @@ async def test_get_pending_reviews_for_execution(
assert where_clause["status"] == ReviewStatus.WAITING
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_process_all_reviews_for_execution_success(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -201,6 +219,14 @@ async def test_process_all_reviews_for_execution_success(
new=AsyncMock(return_value=[updated_review]),
)
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await process_all_reviews_for_execution(
user_id="test-user-123",
review_decisions={
@@ -211,9 +237,10 @@ async def test_process_all_reviews_for_execution_success(
assert len(result) == 1
assert "test_node_123" in result
assert result["test_node_123"].status == ReviewStatus.APPROVED
assert result["test_node_123"].node_id == "test_node_def_789"
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_process_all_reviews_for_execution_validation_errors(
mocker: pytest_mock.MockFixture,
):
@@ -233,7 +260,7 @@ async def test_process_all_reviews_for_execution_validation_errors(
)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_process_all_reviews_edit_permission_error(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -259,7 +286,7 @@ async def test_process_all_reviews_edit_permission_error(
)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="function")
async def test_process_all_reviews_mixed_approval_rejection(
mocker: pytest_mock.MockFixture,
sample_db_review,
@@ -329,6 +356,14 @@ async def test_process_all_reviews_mixed_approval_rejection(
new=AsyncMock(return_value=[approved_review, rejected_review]),
)
# Mock get_node_execution to return node with node_id (async function)
mock_node_exec = Mock()
mock_node_exec.node_id = "test_node_def_789"
mocker.patch(
"backend.data.execution.get_node_execution",
new=AsyncMock(return_value=mock_node_exec),
)
result = await process_all_reviews_for_execution(
user_id="test-user-123",
review_decisions={
@@ -340,3 +375,5 @@ async def test_process_all_reviews_mixed_approval_rejection(
assert len(result) == 2
assert "test_node_123" in result
assert "test_node_456" in result
assert result["test_node_123"].node_id == "test_node_def_789"
assert result["test_node_456"].node_id == "test_node_def_789"

View File

@@ -1,72 +0,0 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.
This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""
from backend.data.llm_registry.model import ModelMetadata
# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
REGISTRY_REFRESH_CHANNEL,
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_default_model_slug,
get_dynamic_model_slugs,
get_fallback_model_for_disabled,
get_llm_discriminator_mapping,
get_llm_model_cost,
get_llm_model_metadata,
get_llm_model_schema_options,
get_model_info,
is_model_enabled,
iter_dynamic_models,
refresh_llm_registry,
register_static_costs,
register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
is_llm_model_field,
refresh_llm_discriminator_mapping,
refresh_llm_model_options,
update_schema_with_llm_registry,
)
__all__ = [
# Types
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Registry functions
"get_all_model_slugs_for_validation",
"get_default_model_slug",
"get_dynamic_model_slugs",
"get_fallback_model_for_disabled",
"get_llm_discriminator_mapping",
"get_llm_model_cost",
"get_llm_model_metadata",
"get_llm_model_schema_options",
"get_model_info",
"is_model_enabled",
"iter_dynamic_models",
"refresh_llm_registry",
"register_static_costs",
"register_static_metadata",
# Notifications
"REGISTRY_REFRESH_CHANNEL",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Schema utilities
"is_llm_model_field",
"refresh_llm_discriminator_mapping",
"refresh_llm_model_options",
"update_schema_with_llm_registry",
]

View File

@@ -1,25 +0,0 @@
"""Type definitions for LLM model metadata."""
from typing import Literal, NamedTuple
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model.
Attributes:
provider: The provider identifier (e.g., "openai", "anthropic")
context_window: Maximum context window size in tokens
max_output_tokens: Maximum output tokens (None if unlimited)
display_name: Human-readable name for the model
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
creator_name: Name of the organization that created the model
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
"""
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]

View File

@@ -1,89 +0,0 @@
"""
Redis pub/sub notifications for LLM registry updates.
When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
from backend.data.redis_client import connect_async
logger = logging.getLogger(__name__)
# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",
exc,
exc_info=True,
)
except Exception as exc:
logger.warning(
"Error processing registry refresh message: %s", exc, exc_info=True
)
# Continue listening even if one message fails
await asyncio.sleep(1)
except Exception as exc:
logger.error(
"Failed to subscribe to LLM registry refresh notifications: %s",
exc,
exc_info=True,
)
raise

View File

@@ -1,388 +0,0 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
"""Register static metadata for legacy models (deprecated)."""
_static_metadata.update({str(key): value for key, value in metadata.items()})
_refresh_cached_schema()
def register_static_costs(costs: dict[Any, int]) -> None:
"""Register static costs for legacy models (deprecated)."""
_static_costs.update({str(key): value for key, value in costs.items()})
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
options: list[dict[str, str]] = []
# Only include enabled models in the dropdown options
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
options.append(
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
)
for slug, metadata in _static_metadata.items():
if slug in _dynamic_models:
continue
options.append(
{
"label": slug,
"value": slug,
"group": metadata.provider,
"description": "",
}
)
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display_name = (
record.Provider.displayName if record.Provider else record.providerId
)
# Creator name: prefer Creator.name, fallback to provider display name
creator_name = (
record.Creator.name if record.Creator else provider_display_name
)
# Price tier: default to 1 (cheapest) if not set
price_tier = getattr(record, "priceTier", 1) or 1
# Clamp to valid range 1-3
price_tier = max(1, min(3, price_tier))
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
display_name=record.displayName,
provider_name=provider_display_name,
creator_name=creator_name,
price_tier=price_tier, # type: ignore[arg-type]
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
# Map creator if present
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
# Atomic swap - build new structures then replace references
# This ensures readers never see partially updated state
global _dynamic_models
_dynamic_models = dynamic
_refresh_cached_schema()
logger.info(
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
len(dynamic),
sum(1 for m in dynamic.values() if m.is_enabled),
sum(1 for m in dynamic.values() if not m.is_enabled),
)
def _refresh_cached_schema() -> None:
"""Refresh cached schema options and discriminator mapping."""
global _schema_options, _discriminator_mapping
# Build new structures
new_options = _build_schema_options()
new_mapping = {
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
}
for slug, metadata in _static_metadata.items():
new_mapping.setdefault(slug, metadata.provider)
# Atomic swap - replace references to ensure readers see consistent state
_schema_options = new_options
_discriminator_mapping = new_mapping
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
if slug in _dynamic_models:
return _dynamic_models[slug].metadata
return _static_metadata.get(slug)
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
"""Get model cost configuration by slug."""
if slug in _dynamic_models:
return _dynamic_models[slug].costs
cost_value = _static_costs.get(slug)
if cost_value is None:
return tuple()
return (
RegistryModelCost(
credit_cost=cost_value,
credential_provider="static",
credential_id=None,
credential_type=None,
currency=None,
metadata={},
),
)
def get_llm_model_schema_options() -> list[dict[str, str]]:
"""
Get schema options for LLM model selection dropdown.
Returns a copy of cached schema options that are refreshed when the registry is
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return list(_schema_options)
def get_llm_discriminator_mapping() -> dict[str, str]:
"""
Get discriminator mapping for LLM models.
Returns a copy of cached discriminator mapping that is refreshed when the registry
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return dict(_discriminator_mapping)
def get_dynamic_model_slugs() -> set[str]:
"""Get all dynamic model slugs from the registry."""
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
"""Iterate over all dynamic models in the registry."""
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider. Prefers models with
similar names or capabilities if possible.
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
def get_default_model_slug() -> str | None:
"""
Get the default model slug to use for block defaults.
Returns the recommended model if set (configured via admin UI),
otherwise returns the first enabled model alphabetically.
Returns None if no models are available or enabled.
"""
# Return the recommended model if one is set and enabled
for model in _dynamic_models.values():
if model.is_recommended and model.is_enabled:
return model.slug
# No recommended model set - find first enabled model alphabetically
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
logger.warning(
"No recommended model set, using '%s' as default",
model.slug,
)
return model.slug
# No enabled models available
if _dynamic_models:
logger.error(
"No enabled models found in registry (%d models registered but all disabled)",
len(_dynamic_models),
)
else:
logger.error("No models registered in LLM registry")
return None

View File

@@ -1,130 +0,0 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data.llm_registry.registry import (
get_all_model_slugs_for_validation,
get_default_model_slug,
get_llm_discriminator_mapping,
get_llm_model_schema_options,
)
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options from the registry.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may have disabled models selected - they should pass validation
and the fallback logic in llm_call() will handle using an alternative model.
"""
fresh_options = get_llm_model_schema_options()
if not fresh_options:
return
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
all_known_slugs = get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
# Set the default value from the registry (gpt-4o if available, else first enabled)
# This ensures new blocks have a sensible default pre-selected
default_slug = get_default_model_slug()
if default_slug:
field_schema["default"] = default_slug
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = get_llm_discriminator_mapping()
if fresh_mapping is not None:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name,
exc,
)

View File

@@ -40,7 +40,6 @@ from pydantic_core import (
)
from typing_extensions import TypedDict
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
@@ -545,9 +544,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
# Do not return anything, just mutate schema in place
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
@@ -696,20 +693,16 @@ def CredentialsField(
This is enforced by the `BlockSchema` base class.
"""
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)
if discriminator_values:
field_schema_extra["discriminator_values"] = discriminator_values
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:

View File

@@ -50,6 +50,8 @@ from backend.data.graph import (
validate_graph_execution_permissions,
)
from backend.data.human_review import (
cancel_pending_reviews_for_execution,
check_approval,
get_or_create_human_review,
has_pending_reviews_for_graph_exec,
update_review_processed_status,
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
get_user_notification_preference = _(get_user_notification_preference)
# Human In The Loop
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
check_approval = _(check_approval)
get_or_create_human_review = _(get_or_create_human_review)
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
update_review_processed_status = _(update_review_processed_status)
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
set_execution_kv_data = d.set_execution_kv_data
# Human In The Loop
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
check_approval = d.check_approval
get_or_create_human_review = d.get_or_create_human_review
update_review_processed_status = d.update_review_processed_status

View File

@@ -1,66 +0,0 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.data import db, llm_registry
from backend.data.block import BlockSchema, initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Initialize blocks (internally refreshes LLM registry and costs)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -702,20 +702,6 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
subscribe_to_registry_updates,
)
asyncio.run_coroutine_threadsafe(
initialize_registry_for_executor(), self.node_execution_loop
)
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import human_review as human_review_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data.block import (
@@ -749,9 +750,27 @@ async def stop_graph_execution(
if graph_exec.status in [
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
]:
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
# If the graph is queued/incomplete/paused for review, terminate immediately
# No need to wait for executor since it's not actively running
# If graph is in REVIEW status, clean up pending reviews before terminating
if graph_exec.status == ExecutionStatus.REVIEW:
# Use human_review_db if Prisma connected, else database manager
review_db = (
human_review_db
if prisma.is_connected()
else get_database_manager_async_client()
)
# Mark all pending reviews as rejected/cancelled
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
graph_exec_id, user_id
)
logger.info(
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
)
graph_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
@@ -887,9 +906,28 @@ async def add_graph_execution(
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
logger.info(f"Queueing execution {graph_exec.id}")
# Update execution status to QUEUED BEFORE publishing to prevent race condition
# where two concurrent requests could both publish the same execution
updated_exec = await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.QUEUED,
)
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
# If another request already updated the status, this execution will not be QUEUED
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
logger.warning(
f"Skipping queue publish for execution {graph_exec.id} - "
f"status update failed or execution already queued by another request"
)
return graph_exec
graph_exec.status = ExecutionStatus.QUEUED
# Publish to execution queue for executor to pick up
# This happens AFTER status update to ensure only one request publishes
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -897,13 +935,6 @@ async def add_graph_execution(
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
# Update execution status to QUEUED
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:

View File

@@ -4,6 +4,7 @@ import pytest
from pytest_mock import MockerFixture
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.data.execution import ExecutionStatus
from backend.util.mock import MockObject
@@ -346,6 +347,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
# Mock the queue and event bus
@@ -611,6 +613,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
@@ -670,3 +673,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
@pytest.mark.asyncio
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
mocker: MockerFixture,
):
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
graph_exec_id = "test-exec-123"
# Mock graph execution in REVIEW status
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_graph_exec.id = graph_exec_id
mock_graph_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = True
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=2 # 2 reviews cancelled
)
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
mock_get_child_executions.return_value = [] # No children
# Call stop_graph_execution with timeout to allow status check
await stop_graph_execution(
user_id=user_id,
graph_exec_id=graph_exec_id,
wait_timeout=1.0, # Wait to allow status check
cascade=True,
)
# Verify pending reviews were cancelled
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
graph_exec_id, user_id
)
# Verify execution status was updated to TERMINATED
mock_execution_db.update_graph_execution_stats.assert_called_once()
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
assert call_kwargs["graph_exec_id"] == graph_exec_id
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
@pytest.mark.asyncio
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
mocker: MockerFixture,
):
"""Test that stop uses database manager when Prisma is not connected."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
graph_exec_id = "test-exec-456"
# Mock graph execution in REVIEW status
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_graph_exec.id = graph_exec_id
mock_graph_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
# Prisma is NOT connected
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = False
# Mock database manager client
mock_get_db_manager = mocker.patch(
"backend.executor.utils.get_database_manager_async_client"
)
mock_db_manager = mocker.AsyncMock()
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=3 # 3 reviews cancelled
)
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
mock_get_db_manager.return_value = mock_db_manager
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
mock_get_child_executions.return_value = [] # No children
# Call stop_graph_execution with timeout
await stop_graph_execution(
user_id=user_id,
graph_exec_id=graph_exec_id,
wait_timeout=1.0,
cascade=True,
)
# Verify database manager was used for cancel_pending_reviews
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
graph_exec_id, user_id
)
# Verify execution status was updated via database manager
mock_db_manager.update_graph_execution_stats.assert_called_once()
@pytest.mark.asyncio
async def test_stop_graph_execution_cascades_to_child_with_reviews(
mocker: MockerFixture,
):
"""Test that stopping parent execution cascades to children and cancels their reviews."""
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
from backend.executor.utils import stop_graph_execution
user_id = "test-user"
parent_exec_id = "parent-exec"
child_exec_id = "child-exec"
# Mock parent execution in RUNNING status
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_parent_exec.id = parent_exec_id
mock_parent_exec.status = ExecutionStatus.RUNNING
# Mock child execution in REVIEW status
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
mock_child_exec.id = child_exec_id
mock_child_exec.status = ExecutionStatus.REVIEW
# Mock dependencies
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_queue_client = mocker.AsyncMock()
mock_get_queue.return_value = mock_queue_client
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_prisma.is_connected.return_value = True
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
return_value=1 # 1 child review cancelled
)
# Mock execution_db to return different status based on which execution is queried
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
# Track call count to simulate status transition
call_count = {"count": 0}
async def get_exec_meta_side_effect(execution_id, user_id):
call_count["count"] += 1
if execution_id == parent_exec_id:
# After a few calls (child processing happens), transition parent to TERMINATED
# This simulates the executor service processing the stop request
if call_count["count"] > 3:
mock_parent_exec.status = ExecutionStatus.TERMINATED
return mock_parent_exec
elif execution_id == child_exec_id:
return mock_child_exec
return None
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
side_effect=get_exec_meta_side_effect
)
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_event_bus = mocker.MagicMock()
mock_event_bus.publish = mocker.AsyncMock()
mock_get_event_bus.return_value = mock_event_bus
# Mock _get_child_executions to return the child
mock_get_child_executions = mocker.patch(
"backend.executor.utils._get_child_executions"
)
def get_children_side_effect(parent_id):
if parent_id == parent_exec_id:
return [mock_child_exec]
return []
mock_get_child_executions.side_effect = get_children_side_effect
# Call stop_graph_execution on parent with cascade=True
await stop_graph_execution(
user_id=user_id,
graph_exec_id=parent_exec_id,
wait_timeout=1.0,
cascade=True,
)
# Verify child reviews were cancelled
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
child_exec_id, user_id
)
# Verify both parent and child status updates
assert mock_execution_db.update_graph_execution_stats.call_count >= 1

View File

@@ -1,935 +0,0 @@
from __future__ import annotations
from typing import Any, Iterable, Sequence, cast
import prisma
import prisma.models
from backend.data.db import transaction
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
def _json_dict(value: Any | None) -> dict[str, Any]:
if not value:
return {}
if isinstance(value, dict):
return value
return {}
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
return llm_model.LlmModelCost(
id=record.id,
unit=record.unit,
credit_cost=record.creditCost,
credential_provider=record.credentialProvider,
credential_id=record.credentialId,
credential_type=record.credentialType,
currency=record.currency,
metadata=_json_dict(record.metadata),
)
def _map_creator(
record: prisma.models.LlmModelCreator,
) -> llm_model.LlmModelCreator:
return llm_model.LlmModelCreator(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
website_url=record.websiteUrl,
logo_url=record.logoUrl,
metadata=_json_dict(record.metadata),
)
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
costs = []
if record.Costs:
costs = [_map_cost(cost) for cost in record.Costs]
creator = None
if hasattr(record, "Creator") and record.Creator:
creator = _map_creator(record.Creator)
return llm_model.LlmModel(
id=record.id,
slug=record.slug,
display_name=record.displayName,
description=record.description,
provider_id=record.providerId,
creator_id=record.creatorId,
creator=creator,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
capabilities=_json_dict(record.capabilities),
metadata=_json_dict(record.metadata),
costs=costs,
)
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
models: list[llm_model.LlmModel] = []
if record.Models:
models = [_map_model(model) for model in record.Models]
return llm_model.LlmProvider(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
default_credential_provider=record.defaultCredentialProvider,
default_credential_id=record.defaultCredentialId,
default_credential_type=record.defaultCredentialType,
supports_tools=record.supportsTools,
supports_json_output=record.supportsJsonOutput,
supports_reasoning=record.supportsReasoning,
supports_parallel_tool=record.supportsParallelTool,
metadata=_json_dict(record.metadata),
models=models,
)
async def list_providers(
include_models: bool = True, enabled_only: bool = False
) -> list[llm_model.LlmProvider]:
"""
List all LLM providers.
Args:
include_models: Whether to include models for each provider
enabled_only: If True, only include enabled models (for public routes)
"""
include: Any = None
if include_models:
model_where = {"isEnabled": True} if enabled_only else None
include = {
"Models": {
"include": {"Costs": True, "Creator": True},
"where": model_where,
}
}
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
async def upsert_provider(
request: llm_model.UpsertLlmProviderRequest,
provider_id: str | None = None,
) -> llm_model.LlmProvider:
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"defaultCredentialProvider": request.default_credential_provider,
"defaultCredentialId": request.default_credential_id,
"defaultCredentialType": request.default_credential_type,
"supportsTools": request.supports_tools,
"supportsJsonOutput": request.supports_json_output,
"supportsReasoning": request.supports_reasoning,
"supportsParallelTool": request.supports_parallel_tool,
"metadata": prisma.Json(request.metadata or {}),
}
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
if provider_id:
record = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
data=data,
include=include,
)
else:
record = await prisma.models.LlmProvider.prisma().create(
data=data,
include=include,
)
if record is None:
raise ValueError("Failed to create/update provider")
return _map_provider(record)
async def delete_provider(provider_id: str) -> bool:
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Due to onDelete: Restrict on LlmModel.Provider, the database will
block deletion if models exist.
Args:
provider_id: UUID of the provider to delete
Returns:
True if deleted successfully
Raises:
ValueError: If provider not found or has associated models
"""
# Check if provider exists
provider = await prisma.models.LlmProvider.prisma().find_unique(
where={"id": provider_id},
include={"Models": True},
)
if not provider:
raise ValueError(f"Provider with id '{provider_id}' not found")
# Check if provider has any models
model_count = len(provider.Models) if provider.Models else 0
if model_count > 0:
raise ValueError(
f"Cannot delete provider '{provider.displayName}' because it has "
f"{model_count} model(s). Delete all models first."
)
# Safe to delete
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
return True
async def list_models(
provider_id: str | None = None,
enabled_only: bool = False,
page: int = 1,
page_size: int = 50,
) -> llm_model.LlmModelsResponse:
"""
List LLM models with pagination.
Args:
provider_id: Optional filter by provider ID
enabled_only: If True, only return enabled models (for public routes)
page: Page number (1-indexed)
page_size: Number of models per page
"""
where: Any = {}
if provider_id:
where["providerId"] = provider_id
if enabled_only:
where["isEnabled"] = True
# Get total count for pagination
total_items = await prisma.models.LlmModel.prisma().count(
where=where if where else None
)
# Calculate pagination
skip = (page - 1) * page_size
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
records = await prisma.models.LlmModel.prisma().find_many(
where=where if where else None,
include={"Costs": True, "Creator": True},
skip=skip,
take=page_size,
)
models = [_map_model(record) for record in records]
return llm_model.LlmModelsResponse(
models=models,
pagination=Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
def _cost_create_payload(
costs: Sequence[llm_model.LlmModelCostInput],
) -> dict[str, Iterable[dict[str, Any]]]:
create_items = []
for cost in costs:
item: dict[str, Any] = {
"unit": cost.unit,
"creditCost": cost.credit_cost,
"credentialProvider": cost.credential_provider,
}
# Only include optional fields if they have values
if cost.credential_id:
item["credentialId"] = cost.credential_id
if cost.credential_type:
item["credentialType"] = cost.credential_type
if cost.currency:
item["currency"] = cost.currency
# Handle metadata - use Prisma Json type
if cost.metadata is not None and cost.metadata != {}:
item["metadata"] = prisma.Json(cost.metadata)
create_items.append(item)
return {"create": create_items}
async def create_model(
request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
data: Any = {
"slug": request.slug,
"displayName": request.display_name,
"description": request.description,
"Provider": {"connect": {"id": request.provider_id}},
"contextWindow": request.context_window,
"maxOutputTokens": request.max_output_tokens,
"isEnabled": request.is_enabled,
"capabilities": prisma.Json(request.capabilities or {}),
"metadata": prisma.Json(request.metadata or {}),
"Costs": _cost_create_payload(request.costs),
}
if request.creator_id:
data["Creator"] = {"connect": {"id": request.creator_id}}
record = await prisma.models.LlmModel.prisma().create(
data=data,
include={"Costs": True, "Creator": True, "Provider": True},
)
return _map_model(record)
async def update_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
# Build scalar field updates (non-relation fields)
scalar_data: Any = {}
if request.display_name is not None:
scalar_data["displayName"] = request.display_name
if request.description is not None:
scalar_data["description"] = request.description
if request.context_window is not None:
scalar_data["contextWindow"] = request.context_window
if request.max_output_tokens is not None:
scalar_data["maxOutputTokens"] = request.max_output_tokens
if request.is_enabled is not None:
scalar_data["isEnabled"] = request.is_enabled
if request.capabilities is not None:
scalar_data["capabilities"] = request.capabilities
if request.metadata is not None:
scalar_data["metadata"] = request.metadata
# Foreign keys can be updated directly as scalar fields
if request.provider_id is not None:
scalar_data["providerId"] = request.provider_id
if request.creator_id is not None:
# Empty string means remove the creator
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
# If we have costs to update, we need to handle them separately
# because nested writes have different constraints
if request.costs is not None:
# Wrap cost replacement in a transaction for atomicity
async with transaction() as tx:
# First update scalar fields
if scalar_data:
await tx.llmmodel.update(
where={"id": model_id},
data=scalar_data,
)
# Then handle costs: delete existing and create new
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
if request.costs:
cost_payload = _cost_create_payload(request.costs)
for cost_item in cost_payload["create"]:
cost_item["llmModelId"] = model_id
await tx.llmmodelcost.create(data=cast(Any, cost_item))
# Fetch the updated record (outside transaction)
record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
else:
# No costs update - simple update
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data=scalar_data,
include={"Costs": True, "Creator": True},
)
if not record:
raise ValueError(f"Model with id '{model_id}' not found")
return _map_model(record)
async def toggle_model(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> llm_model.ToggleLlmModelResponse:
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
Args:
model_id: UUID of the model to toggle
is_enabled: New enabled status
migrate_to_slug: If disabling and this is provided, migrate all workflows
using this model to the specified replacement model
migration_reason: Optional reason for the migration (e.g., "Provider outage")
custom_credit_cost: Optional custom pricing override for migrated workflows.
When set, the billing system should use this cost instead
of the target model's cost for affected nodes.
Returns:
ToggleLlmModelResponse with the updated model and optional migration stats
"""
import json
# Get the model being toggled
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
nodes_migrated = 0
migration_id: str | None = None
# If disabling with migration, perform migration first
if not is_enabled and migrate_to_slug:
# Validate replacement model exists and is enabled
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# Perform all operations atomically within a single transaction
# This ensures no nodes are missed between query and update
async with transaction() as tx:
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
# Update by IDs to ensure we only update the exact nodes we queried
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
node_ids_json,
)
record = await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
# Create migration record for revert capability
if nodes_migrated > 0:
migration_data: Any = {
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
migration_record = await tx.llmmodelmigration.create(
data=migration_data
)
migration_id = migration_record.id
else:
# Simple toggle without migration
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
if record is None:
raise ValueError(f"Model with id '{model_id}' not found")
return llm_model.ToggleLlmModelResponse(
model=_map_model(record),
nodes_migrated=nodes_migrated,
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
migration_id=migration_id,
)
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
"""Get usage count for a model."""
import prisma as prisma_module
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
model.slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
async def delete_model(
model_id: str, replacement_model_slug: str | None = None
) -> llm_model.DeleteLlmModelResponse:
"""
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
This performs an atomic operation within a database transaction:
1. Validates the model exists
2. Counts affected nodes
3. If nodes exist, validates replacement model and migrates them
4. Deletes the LlmModel record (CASCADE deletes costs)
Args:
model_id: UUID of the model to delete
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
Returns:
DeleteLlmModelResponse with migration stats
Raises:
ValueError: If model not found, nodes exist but no replacement provided,
replacement not found, or replacement is disabled
"""
# 1. Get the model being deleted (validation - outside transaction)
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
# 2. Count affected nodes first to determine if replacement is needed
import prisma as prisma_module
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
# 3. Validate replacement model only if there are nodes to migrate
if nodes_to_migrate > 0:
if not replacement_model_slug:
raise ValueError(
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
f"are using it. Please provide a replacement_model_slug to migrate them."
)
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# 4. Perform migration (if needed) and deletion atomically within a transaction
async with transaction() as tx:
# Migrate all AgentNode.constantInput->model to replacement
if nodes_to_migrate > 0 and replacement_model_slug:
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
# Delete the model (CASCADE will delete costs automatically)
await tx.llmmodel.delete(where={"id": model_id})
# Build appropriate message based on whether migration happened
if nodes_to_migrate > 0:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
)
else:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
f"No workflows were using this model."
)
return llm_model.DeleteLlmModelResponse(
deleted_model_slug=deleted_slug,
deleted_model_display_name=deleted_display_name,
replacement_model_slug=replacement_model_slug,
nodes_migrated=nodes_to_migrate,
message=message,
)
def _map_migration(
record: prisma.models.LlmModelMigration,
) -> llm_model.LlmModelMigration:
return llm_model.LlmModelMigration(
id=record.id,
source_model_slug=record.sourceModelSlug,
target_model_slug=record.targetModelSlug,
reason=record.reason,
node_count=record.nodeCount,
custom_credit_cost=record.customCreditCost,
is_reverted=record.isReverted,
created_at=record.createdAt.isoformat(),
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
)
async def list_migrations(
include_reverted: bool = False,
) -> list[llm_model.LlmModelMigration]:
"""
List model migrations, optionally including reverted ones.
Args:
include_reverted: If True, include reverted migrations. Default is False.
Returns:
List of LlmModelMigration records
"""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [_map_migration(record) for record in records]
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
"""Get a specific migration by ID."""
record = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
return _map_migration(record) if record else None
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> llm_model.RevertMigrationResponse:
"""
Revert a model migration, restoring affected nodes to their original model.
This only reverts the specific nodes that were migrated, not all nodes
currently using the target model.
Args:
migration_id: UUID of the migration to revert
re_enable_source_model: Whether to re-enable the source model if it's disabled
Returns:
RevertMigrationResponse with revert stats
Raises:
ValueError: If migration not found, already reverted, or source model not available
"""
import json
from datetime import datetime, timezone
# Get the migration record
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted "
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
)
# Check if source model exists
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists. "
f"Cannot revert migration."
)
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
# Track if we need to re-enable the source model
source_model_was_disabled = not source_model.isEnabled
should_re_enable = source_model_was_disabled and re_enable_source_model
source_model_re_enabled = False
# Perform revert atomically
async with transaction() as tx:
# Re-enable the source model if requested and it was disabled
if should_re_enable:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
# Update only the specific nodes that were migrated
# We need to check that they still have the target model (haven't been changed since)
# Use a single batch update for efficiency
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_json,
migration.targetModelSlug,
)
nodes_reverted = result if result else 0
# Mark migration as reverted
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
# Calculate nodes that were already changed since migration
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
# Build appropriate message
message_parts = [
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
]
if nodes_already_changed > 0:
message_parts.append(
f" {nodes_already_changed} node(s) were already changed and not reverted."
)
if source_model_re_enabled:
message_parts.append(
f" Model '{migration.sourceModelSlug}' has been re-enabled."
)
return llm_model.RevertMigrationResponse(
migration_id=migration_id,
source_model_slug=migration.sourceModelSlug,
target_model_slug=migration.targetModelSlug,
nodes_reverted=nodes_reverted,
nodes_already_changed=nodes_already_changed,
source_model_re_enabled=source_model_re_enabled,
message="".join(message_parts),
)
# ============================================================================
# Creator CRUD operations
# ============================================================================
async def list_creators() -> list[llm_model.LlmModelCreator]:
"""List all LLM model creators."""
records = await prisma.models.LlmModelCreator.prisma().find_many(
order={"displayName": "asc"}
)
return [_map_creator(record) for record in records]
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
"""Get a specific creator by ID."""
record = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
return _map_creator(record) if record else None
async def upsert_creator(
request: llm_model.UpsertLlmCreatorRequest,
creator_id: str | None = None,
) -> llm_model.LlmModelCreator:
"""Create or update a model creator."""
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"websiteUrl": request.website_url,
"logoUrl": request.logo_url,
"metadata": prisma.Json(request.metadata or {}),
}
if creator_id:
record = await prisma.models.LlmModelCreator.prisma().update(
where={"id": creator_id},
data=data,
)
else:
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
if record is None:
raise ValueError("Failed to create/update creator")
return _map_creator(record)
async def delete_creator(creator_id: str) -> bool:
"""
Delete a model creator.
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
Args:
creator_id: UUID of the creator to delete
Returns:
True if deleted successfully
Raises:
ValueError: If creator not found
"""
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
if not creator:
raise ValueError(f"Creator with id '{creator_id}' not found")
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
return True
async def get_recommended_model() -> llm_model.LlmModel | None:
"""
Get the currently recommended LLM model.
Returns:
The recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
include={"Costs": True, "Creator": True},
)
return _map_model(record) if record else None
async def set_recommended_model(
model_id: str,
) -> tuple[llm_model.LlmModel, str | None]:
"""
Set a model as the recommended model.
This will clear the isRecommended flag from any other model and set it
on the specified model. The model must be enabled.
Args:
model_id: UUID of the model to set as recommended
Returns:
Tuple of (the updated model, previous recommended model slug or None)
Raises:
ValueError: If model not found or not enabled
"""
# First, verify the model exists and is enabled
target_model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}
)
if not target_model:
raise ValueError(f"Model with id '{model_id}' not found")
if not target_model.isEnabled:
raise ValueError(
f"Cannot set disabled model '{target_model.slug}' as recommended"
)
# Get the current recommended model (if any)
current_recommended = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True}
)
previous_slug = current_recommended.slug if current_recommended else None
# Use a transaction to ensure atomicity
async with transaction() as tx:
# Clear isRecommended from all models
await tx.llmmodel.update_many(
where={"isRecommended": True},
data={"isRecommended": False},
)
# Set the new recommended model
await tx.llmmodel.update(
where={"id": model_id},
data={"isRecommended": True},
)
# Fetch and return the updated model
updated_record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
if not updated_record:
raise ValueError("Failed to fetch updated model")
return _map_model(updated_record), previous_slug
async def get_recommended_model_slug() -> str | None:
"""
Get the slug of the currently recommended LLM model.
Returns:
The slug of the recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
)
return record.slug if record else None

View File

@@ -1,235 +0,0 @@
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Optional
import prisma.enums
import pydantic
from backend.util.models import Pagination
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
class LlmModelCost(pydantic.BaseModel):
id: str
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = None
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
id: str
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModel(pydantic.BaseModel):
id: str
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
creator: Optional[LlmModelCreator] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
id: str
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = None
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmProvidersResponse(pydantic.BaseModel):
providers: list[LlmProvider]
class LlmModelsResponse(pydantic.BaseModel):
models: list[LlmModel]
pagination: Optional[Pagination] = None
class LlmCreatorsResponse(pydantic.BaseModel):
creators: list[LlmModelCreator]
class UpsertLlmProviderRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = "api_key"
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class UpsertLlmCreatorRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCostInput(pydantic.BaseModel):
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = "api_key"
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class CreateLlmModelRequest(pydantic.BaseModel):
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCostInput]
@pydantic.field_validator("slug")
@classmethod
def validate_slug(cls, v: str) -> str:
if not v or len(v) > 100:
raise ValueError("Slug must be 1-100 characters")
if not SLUG_PATTERN.match(v):
raise ValueError(
"Slug must start with alphanumeric and contain only "
"alphanumeric characters, dots, underscores, slashes, or hyphens"
)
return v
class UpdateLlmModelRequest(pydantic.BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
context_window: Optional[int] = None
max_output_tokens: Optional[int] = None
is_enabled: Optional[bool] = None
capabilities: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
provider_id: Optional[str] = None
creator_id: Optional[str] = None
costs: Optional[list[LlmModelCostInput]] = None
class ToggleLlmModelRequest(pydantic.BaseModel):
is_enabled: bool
migrate_to_slug: Optional[str] = None
migration_reason: Optional[str] = None # e.g., "Provider outage"
# Custom pricing override for migrated workflows. When set, billing should use
# this cost instead of the target model's cost for affected nodes.
# See LlmModelMigration in schema.prisma for full documentation.
custom_credit_cost: Optional[int] = None
class ToggleLlmModelResponse(pydantic.BaseModel):
model: LlmModel
nodes_migrated: int = 0
migrated_to_slug: Optional[str] = None
migration_id: Optional[str] = None # ID of the migration record for revert
class DeleteLlmModelResponse(pydantic.BaseModel):
deleted_model_slug: str
deleted_model_display_name: str
replacement_model_slug: Optional[str] = None
nodes_migrated: int
message: str
class LlmModelUsageResponse(pydantic.BaseModel):
model_slug: str
node_count: int
# Migration tracking models
class LlmModelMigration(pydantic.BaseModel):
id: str
source_model_slug: str
target_model_slug: str
reason: Optional[str] = None
node_count: int
# Custom pricing override - billing should use this instead of target model's cost
custom_credit_cost: Optional[int] = None
is_reverted: bool = False
created_at: datetime
reverted_at: Optional[datetime] = None
class LlmMigrationsResponse(pydantic.BaseModel):
migrations: list[LlmModelMigration]
class RevertMigrationRequest(pydantic.BaseModel):
re_enable_source_model: bool = (
True # Whether to re-enable the source model if disabled
)
class RevertMigrationResponse(pydantic.BaseModel):
migration_id: str
source_model_slug: str
target_model_slug: str
nodes_reverted: int
nodes_already_changed: int = (
0 # Nodes that were modified since migration (not reverted)
)
source_model_re_enabled: bool = False # Whether the source model was re-enabled
message: str
class SetRecommendedModelRequest(pydantic.BaseModel):
model_id: str
class SetRecommendedModelResponse(pydantic.BaseModel):
model: LlmModel
previous_recommended_slug: Optional[str] = None
message: str
class RecommendedModelResponse(pydantic.BaseModel):
model: Optional[LlmModel] = None
slug: Optional[str] = None

View File

@@ -1,29 +0,0 @@
import autogpt_libs.auth
import fastapi
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
"""List all enabled LLM models available to users."""
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""List all LLM providers with their enabled models."""
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -350,6 +350,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Whether to mark failed scans as clean or not",
)
agentgenerator_host: str = Field(
default="",
description="The host for the Agent Generator service (empty to use built-in)",
)
agentgenerator_port: int = Field(
default=8000,
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=120,
description="The timeout in seconds for Agent Generator service requests",
)
enable_example_blocks: bool = Field(
default=False,
description="Whether to enable example blocks in production",

View File

@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
import time
@@ -58,6 +59,11 @@ class SpinTestServer:
self.db_api.__exit__(exc_type, exc_val, exc_tb)
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
# Give services time to fully shut down
# This prevents event loop issues where services haven't fully cleaned up
# before the next test starts
await asyncio.sleep(0.5)
def setup_dependency_overrides(self):
# Override get_user_id for testing
self.agent_server.set_test_dependency_overrides(

View File

@@ -1,81 +0,0 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
-- CreateIndex
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -1,226 +0,0 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
-- Insert Providers
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
VALUES
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Models (using CTEs to reference provider IDs)
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- OpenAI models
('o3', 'O3', 'openai', 200000, 100000),
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
('o1', 'O1', 'openai', 200000, 100000),
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
-- Anthropic models
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
-- AI/ML API models
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
-- Groq models
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
-- Ollama models
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
('llama3', 'Llama 3', 'ollama', 8192, NULL),
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
-- OpenRouter models
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
-- Llama API models
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
-- v0 models
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- OpenAI costs
('o3', 4),
('o3-mini', 2),
('o1', 16),
('o1-mini', 4),
('gpt-5-2025-08-07', 2),
('gpt-5.1-2025-11-13', 5),
('gpt-5-mini-2025-08-07', 1),
('gpt-5-nano-2025-08-07', 1),
('gpt-5-chat-latest', 5),
('gpt-4.1-2025-04-14', 2),
('gpt-4.1-mini-2025-04-14', 1),
('gpt-4o-mini', 1),
('gpt-4o', 3),
('gpt-4-turbo', 10),
('gpt-3.5-turbo', 1),
-- Anthropic costs
('claude-opus-4-1-20250805', 21),
('claude-opus-4-20250514', 21),
('claude-sonnet-4-20250514', 5),
('claude-haiku-4-5-20251001', 4),
('claude-opus-4-5-20251101', 14),
('claude-sonnet-4-5-20250929', 9),
('claude-3-7-sonnet-20250219', 5),
('claude-3-haiku-20240307', 1),
-- AI/ML API costs
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
-- Groq costs
('llama-3.3-70b-versatile', 1),
('llama-3.1-8b-instant', 1),
-- Ollama costs
('llama3.3', 1),
('llama3.2', 1),
('llama3', 1),
('llama3.1:405b', 1),
('dolphin-mistral:latest', 1),
-- OpenRouter costs
('google/gemini-2.5-pro-preview-03-25', 4),
('google/gemini-3-pro-preview', 5),
('mistralai/mistral-nemo', 1),
('cohere/command-r-08-2024', 1),
('cohere/command-r-plus-08-2024', 3),
('deepseek/deepseek-chat', 2),
('perplexity/sonar', 1),
('perplexity/sonar-pro', 5),
('perplexity/sonar-deep-research', 10),
('nousresearch/hermes-3-llama-3.1-405b', 1),
('nousresearch/hermes-3-llama-3.1-70b', 1),
('amazon/nova-lite-v1', 1),
('amazon/nova-micro-v1', 1),
('amazon/nova-pro-v1', 1),
('microsoft/wizardlm-2-8x22b', 1),
('gryphe/mythomax-l2-13b', 1),
('meta-llama/llama-4-scout', 1),
('meta-llama/llama-4-maverick', 1),
('x-ai/grok-4', 9),
('x-ai/grok-4-fast', 1),
('x-ai/grok-4.1-fast', 1),
('x-ai/grok-code-fast-1', 1),
('moonshotai/kimi-k2', 1),
('qwen/qwen3-235b-a22b-thinking-2507', 1),
('qwen/qwen3-coder', 9),
('google/gemini-2.5-flash', 1),
('google/gemini-2.0-flash-001', 1),
('google/gemini-2.5-flash-lite-preview-06-17', 1),
('google/gemini-2.0-flash-lite-001', 1),
('deepseek/deepseek-r1-0528', 1),
('openai/gpt-oss-120b', 1),
('openai/gpt-oss-20b', 1),
-- Llama API costs
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
('Llama-3.3-8B-Instruct', 1),
('Llama-3.3-70B-Instruct', 1),
-- v0 costs
('v0-1.5-md', 1),
('v0-1.5-lg', 2),
('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;

View File

@@ -1,25 +0,0 @@
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");

View File

@@ -1,127 +0,0 @@
-- Add LlmModelCreator table
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
-- Create the LlmModelCreator table
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- Create unique index on name
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- Add creatorId column to LlmModel
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
-- Add foreign key constraint
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- Create index on creatorId
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Seed creators based on known model creators
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
ON CONFLICT ("name") DO NOTHING;
-- Update existing models with their creators
-- OpenAI models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
-- Anthropic models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
WHERE "slug" LIKE 'claude-%';
-- Meta/Llama models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
-- Google models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
-- Mistral models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
-- Cohere models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
-- DeepSeek models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
-- Perplexity models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
-- Qwen models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
-- xAI/Grok models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
-- Amazon models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
-- Microsoft models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
-- Moonshot models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
-- NVIDIA models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
-- Nous Research models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
-- Vercel/v0 models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
WHERE "slug" LIKE 'v0-%';
-- Dolphin models (Cognitive Computations / Eric Hartford)
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
WHERE "slug" LIKE 'dolphin-%';
-- Gryphe models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';

View File

@@ -1,4 +0,0 @@
-- CreateIndex
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
-- This improves performance of model migration queries in the LLM registry
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));

View File

@@ -1,52 +0,0 @@
-- Add GPT-5.2 model and update O3 slug
-- This migration adds the new GPT-5.2 model added in dev branch
-- Update O3 slug to match dev branch format
UPDATE "LlmModel"
SET "slug" = 'o3-2025-04-16'
WHERE "slug" = 'o3';
-- Update cost reference for O3 if needed
-- (costs are linked by model ID, so no update needed)
-- Add GPT-5.2 model
WITH provider_id AS (
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
'gpt-5.2-2025-12-11',
'GPT 5.2',
'OpenAI GPT-5.2 model',
p."id",
400000,
128000,
true,
'{}'::jsonb,
'{}'::jsonb
FROM provider_id p
ON CONFLICT ("slug") DO NOTHING;
-- Add cost for GPT-5.2
WITH model_id AS (
SELECT m."id", p."name" as provider_name
FROM "LlmModel" m
JOIN "LlmProvider" p ON p."id" = m."providerId"
WHERE m."slug" = 'gpt-5.2-2025-12-11'
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
3, -- Same cost tier as GPT-5.1
m.provider_name,
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM model_id m
WHERE NOT EXISTS (
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
);

View File

@@ -1,11 +0,0 @@
-- Add isRecommended field to LlmModel table
-- This allows admins to mark a model as the recommended default
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
-- Set gpt-4o-mini as the default recommended model (if it exists)
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
-- Create unique partial index to enforce only one recommended model at the database level
-- This prevents multiple rows from having isRecommended = true
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;

View File

@@ -0,0 +1,7 @@
-- Remove NodeExecution foreign key from PendingHumanReview
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
-- to AgentNodeExecution since PendingHumanReview records can persist after node
-- execution records are deleted.
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";

View File

@@ -1,61 +0,0 @@
-- Add new columns to LlmModel table for extended model metadata
-- These columns support the LLM Picker UI enhancements
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
-- Add creatorId column for model creator relationship (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
-- Add isRecommended column (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
-- Add index on creatorId if not exists
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Add foreign key for creatorId if not exists
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
-- Only add FK if LlmModelCreator table exists
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
END IF;
END IF;
END $$;
-- Update priceTier values for existing models based on original MODEL_METADATA
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
-- OpenAI models
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
-- Anthropic models
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
-- OpenRouter models - Pro/expensive tiers
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';

View File

@@ -517,8 +517,6 @@ model AgentNodeExecution {
stats Json?
PendingHumanReview PendingHumanReview?
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
@@index([agentNodeId, executionStatus])
@@index([addedTime, queuedTime])
@@ -567,6 +565,7 @@ enum ReviewStatus {
}
// Pending human reviews for Human-in-the-loop blocks
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
model PendingHumanReview {
nodeExecId String @id
userId String
@@ -585,7 +584,6 @@ model PendingHumanReview {
reviewedAt DateTime?
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
@@unique([nodeExecId]) // One pending review per node execution
@@ -1096,153 +1094,6 @@ enum APIKeyStatus {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
///////////// LLM REGISTRY AND BILLING DATA /////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
// This is distinct from BlockCostType (in backend/data/block.py) which defines
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
// while BlockCostType is for billing platform block executions.
enum LlmCostUnit {
RUN
TOKENS
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
supportsTools Boolean @default(true)
supportsJsonOutput Boolean @default(true)
supportsReasoning Boolean @default(false)
supportsParallelTool Boolean @default(false)
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
@@index([providerId, isEnabled])
@@index([creatorId])
@@index([slug])
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
@@unique([llmModelId, credentialProvider, unit])
@@index([llmModelId])
@@index([credentialProvider])
}
// Tracks model migrations for revert capability
// When a model is disabled with migration, we record which nodes were affected
// so they can be reverted when the original model is back online
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
customCreditCost Int?
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
@@index([sourceModelSlug])
@@index([targetModelSlug])
@@index([isReverted])
}
////////////// OAUTH PROVIDER TABLES //////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

View File

@@ -34,7 +34,10 @@ logger = logging.getLogger(__name__)
# Default output directory relative to repo root
DEFAULT_OUTPUT_DIR = (
Path(__file__).parent.parent.parent.parent / "docs" / "integrations"
Path(__file__).parent.parent.parent.parent
/ "docs"
/ "integrations"
/ "block-integrations"
)
@@ -421,6 +424,14 @@ def generate_block_markdown(
lines.append("<!-- END MANUAL -->")
lines.append("")
# Optional per-block extras (only include if has content)
extras = manual_content.get("extras", "")
if extras:
lines.append("<!-- MANUAL: extras -->")
lines.append(extras)
lines.append("<!-- END MANUAL -->")
lines.append("")
lines.append("---")
lines.append("")
@@ -456,25 +467,52 @@ def get_block_file_mapping(blocks: list[BlockDoc]) -> dict[str, list[BlockDoc]]:
return dict(file_mapping)
def generate_overview_table(blocks: list[BlockDoc]) -> str:
"""Generate the overview table markdown (blocks.md)."""
def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "") -> str:
"""Generate the overview table markdown (blocks.md).
Args:
blocks: List of block documentation objects
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
"""
lines = []
# GitBook YAML frontmatter
lines.append("---")
lines.append("layout:")
lines.append(" width: default")
lines.append(" title:")
lines.append(" visible: true")
lines.append(" description:")
lines.append(" visible: true")
lines.append(" tableOfContents:")
lines.append(" visible: false")
lines.append(" outline:")
lines.append(" visible: true")
lines.append(" pagination:")
lines.append(" visible: true")
lines.append(" metadata:")
lines.append(" visible: true")
lines.append("---")
lines.append("")
lines.append("# AutoGPT Blocks Overview")
lines.append("")
lines.append(
'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.'
)
lines.append("")
lines.append('!!! info "Creating Your Own Blocks"')
lines.append(" Want to create your own custom blocks? Check out our guides:")
lines.append(" ")
lines.append('{% hint style="info" %}')
lines.append("**Creating Your Own Blocks**")
lines.append("")
lines.append("Want to create your own custom blocks? Check out our guides:")
lines.append("")
lines.append(
" - [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
"* [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
)
lines.append(
" - [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
"* [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
)
lines.append("{% endhint %}")
lines.append("")
lines.append(
"Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation."
@@ -537,7 +575,8 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
else "No description"
)
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
link_path = f"{block_dir_prefix}{file_path}"
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
lines.append("")
continue
@@ -563,13 +602,55 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
)
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
link_path = f"{block_dir_prefix}{file_path}"
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
lines.append("")
return "\n".join(lines)
def generate_summary_md(
blocks: list[BlockDoc], root_dir: Path, block_dir_prefix: str = ""
) -> str:
"""Generate SUMMARY.md for GitBook navigation.
Args:
blocks: List of block documentation objects
root_dir: The root docs directory (e.g., docs/integrations/)
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
"""
lines = []
lines.append("# Table of contents")
lines.append("")
lines.append("* [AutoGPT Blocks Overview](README.md)")
lines.append("")
# Check for guides/ directory at the root level (docs/integrations/guides/)
guides_dir = root_dir / "guides"
if guides_dir.exists():
lines.append("## Guides")
lines.append("")
for guide_file in sorted(guides_dir.glob("*.md")):
# Use just the file name for title (replace hyphens/underscores with spaces)
title = file_path_to_title(guide_file.stem.replace("-", "_") + ".md")
lines.append(f"* [{title}](guides/{guide_file.name})")
lines.append("")
lines.append("## Block Integrations")
lines.append("")
file_mapping = get_block_file_mapping(blocks)
for file_path in sorted(file_mapping.keys()):
title = file_path_to_title(file_path)
link_path = f"{block_dir_prefix}{file_path}"
lines.append(f"* [{title}]({link_path})")
lines.append("")
return "\n".join(lines)
def load_all_blocks_for_docs() -> list[BlockDoc]:
"""Load all blocks and extract documentation."""
from backend.blocks import load_all_blocks
@@ -653,6 +734,16 @@ def write_block_docs(
)
)
# Add file-level additional_content section if present
file_additional = extract_manual_content(existing_content).get(
"additional_content", ""
)
if file_additional:
content_parts.append("<!-- MANUAL: additional_content -->")
content_parts.append(file_additional)
content_parts.append("<!-- END MANUAL -->")
content_parts.append("")
full_content = file_header + "\n" + "\n".join(content_parts)
generated_files[str(file_path)] = full_content
@@ -661,14 +752,28 @@ def write_block_docs(
full_path.write_text(full_content)
# Generate overview file
overview_content = generate_overview_table(blocks)
overview_path = output_dir / "README.md"
# Generate overview file at the parent directory (docs/integrations/)
# with links prefixed to point into block-integrations/
root_dir = output_dir.parent
block_dir_name = output_dir.name # "block-integrations"
block_dir_prefix = f"{block_dir_name}/"
overview_content = generate_overview_table(blocks, block_dir_prefix)
overview_path = root_dir / "README.md"
generated_files["README.md"] = overview_content
overview_path.write_text(overview_content)
if verbose:
print(" Writing README.md (overview)")
print(" Writing README.md (overview) to parent directory")
# Generate SUMMARY.md for GitBook navigation at the parent directory
summary_content = generate_summary_md(blocks, root_dir, block_dir_prefix)
summary_path = root_dir / "SUMMARY.md"
generated_files["SUMMARY.md"] = summary_content
summary_path.write_text(summary_content)
if verbose:
print(" Writing SUMMARY.md (navigation) to parent directory")
return generated_files
@@ -748,6 +853,16 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
elif block_match.group(1).strip() != expected_block_content.strip():
mismatched_blocks.append(block.name)
# Add file-level additional_content to expected content (matches write_block_docs)
file_additional = extract_manual_content(existing_content).get(
"additional_content", ""
)
if file_additional:
content_parts.append("<!-- MANUAL: additional_content -->")
content_parts.append(file_additional)
content_parts.append("<!-- END MANUAL -->")
content_parts.append("")
expected_content = file_header + "\n" + "\n".join(content_parts)
if existing_content.strip() != expected_content.strip():
@@ -757,11 +872,15 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
out_of_sync_details.append((file_path, mismatched_blocks))
all_match = False
# Check overview
overview_path = output_dir / "README.md"
# Check overview at the parent directory (docs/integrations/)
root_dir = output_dir.parent
block_dir_name = output_dir.name # "block-integrations"
block_dir_prefix = f"{block_dir_name}/"
overview_path = root_dir / "README.md"
if overview_path.exists():
existing_overview = overview_path.read_text()
expected_overview = generate_overview_table(blocks)
expected_overview = generate_overview_table(blocks, block_dir_prefix)
if existing_overview.strip() != expected_overview.strip():
print("OUT OF SYNC: README.md (overview)")
print(" The blocks overview table needs regeneration")
@@ -772,6 +891,21 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
out_of_sync_details.append(("README.md", ["overview table"]))
all_match = False
# Check SUMMARY.md at the parent directory
summary_path = root_dir / "SUMMARY.md"
if summary_path.exists():
existing_summary = summary_path.read_text()
expected_summary = generate_summary_md(blocks, root_dir, block_dir_prefix)
if existing_summary.strip() != expected_summary.strip():
print("OUT OF SYNC: SUMMARY.md (navigation)")
print(" The GitBook navigation needs regeneration")
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
all_match = False
else:
print("MISSING: SUMMARY.md (navigation)")
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
all_match = False
# Check for unfilled manual sections
unfilled_patterns = [
"_Add a description of this category of blocks._",

View File

@@ -0,0 +1 @@
"""Tests for agent generator module."""

View File

@@ -0,0 +1,273 @@
"""
Tests for the Agent Generator core module.
This test suite verifies that the core functions correctly delegate to
the external Agent Generator service.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.api.features.chat.tools.agent_generator import core
from backend.api.features.chat.tools.agent_generator.core import (
AgentGeneratorNotConfiguredError,
)
class TestServiceNotConfigured:
"""Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""
@pytest.mark.asyncio
async def test_decompose_goal_raises_when_not_configured(self):
"""Test that decompose_goal raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.decompose_goal("Build a chatbot")
@pytest.mark.asyncio
async def test_generate_agent_raises_when_not_configured(self):
"""Test that generate_agent raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.generate_agent({"steps": []})
@pytest.mark.asyncio
async def test_generate_agent_patch_raises_when_not_configured(self):
"""Test that generate_agent_patch raises error when service not configured."""
with patch.object(core, "is_external_service_configured", return_value=False):
with pytest.raises(AgentGeneratorNotConfiguredError):
await core.generate_agent_patch("Add a node", {"nodes": []})
class TestDecomposeGoal:
"""Test decompose_goal function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that decompose_goal calls the external service."""
expected_result = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
result = await core.decompose_goal("Build a chatbot")
mock_external.assert_called_once_with("Build a chatbot", "")
assert result == expected_result
@pytest.mark.asyncio
async def test_passes_context_to_external_service(self):
"""Test that decompose_goal passes context to external service."""
expected_result = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
await core.decompose_goal("Build a chatbot", "Use Python")
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
@pytest.mark.asyncio
async def test_returns_none_on_service_failure(self):
"""Test that decompose_goal returns None when external service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "decompose_goal_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.decompose_goal("Build a chatbot")
assert result is None
class TestGenerateAgent:
"""Test generate_agent function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that generate_agent calls the external service."""
expected_result = {"name": "Test Agent", "nodes": [], "links": []}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions)
# Result should have id, version, is_active added if not present
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
assert result["version"] == 1
assert result["is_active"] is True
@pytest.mark.asyncio
async def test_preserves_existing_id_and_version(self):
"""Test that external service result preserves existing id and version."""
expected_result = {
"id": "existing-id",
"version": 3,
"is_active": False,
"name": "Test Agent",
}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result.copy()
result = await core.generate_agent({"steps": []})
assert result is not None
assert result["id"] == "existing-id"
assert result["version"] == 3
assert result["is_active"] is False
@pytest.mark.asyncio
async def test_returns_none_when_external_service_fails(self):
"""Test that generate_agent returns None when external service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.generate_agent({"steps": []})
assert result is None
class TestGenerateAgentPatch:
"""Test generate_agent_patch function service delegation."""
@pytest.mark.asyncio
async def test_calls_external_service(self):
"""Test that generate_agent_patch calls the external service."""
expected_result = {"name": "Updated Agent", "nodes": [], "links": []}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with("Add a node", current_agent)
assert result == expected_result
@pytest.mark.asyncio
async def test_returns_clarifying_questions(self):
"""Test that generate_agent_patch returns clarifying questions."""
expected_result = {
"type": "clarifying_questions",
"questions": [{"question": "What type of node?"}],
}
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = expected_result
result = await core.generate_agent_patch("Add a node", {"nodes": []})
assert result == expected_result
@pytest.mark.asyncio
async def test_returns_none_when_external_service_fails(self):
"""Test that generate_agent_patch returns None when service fails."""
with patch.object(
core, "is_external_service_configured", return_value=True
), patch.object(
core, "generate_agent_patch_external", new_callable=AsyncMock
) as mock_external:
mock_external.return_value = None
result = await core.generate_agent_patch("Add a node", {"nodes": []})
assert result is None
class TestJsonToGraph:
"""Test json_to_graph function."""
def test_converts_agent_json_to_graph(self):
"""Test conversion of agent JSON to Graph model."""
agent_json = {
"id": "test-id",
"version": 2,
"is_active": True,
"name": "Test Agent",
"description": "A test agent",
"nodes": [
{
"id": "node1",
"block_id": "block1",
"input_default": {"key": "value"},
"metadata": {"x": 100},
}
],
"links": [
{
"id": "link1",
"source_id": "node1",
"sink_id": "output",
"source_name": "result",
"sink_name": "input",
"is_static": False,
}
],
}
graph = core.json_to_graph(agent_json)
assert graph.id == "test-id"
assert graph.version == 2
assert graph.is_active is True
assert graph.name == "Test Agent"
assert graph.description == "A test agent"
assert len(graph.nodes) == 1
assert graph.nodes[0].id == "node1"
assert graph.nodes[0].block_id == "block1"
assert len(graph.links) == 1
assert graph.links[0].source_id == "node1"
def test_generates_ids_if_missing(self):
"""Test that missing IDs are generated."""
agent_json = {
"name": "Test Agent",
"nodes": [{"block_id": "block1"}],
"links": [],
}
graph = core.json_to_graph(agent_json)
assert graph.id is not None
assert graph.nodes[0].id is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,422 @@
"""
Tests for the Agent Generator external service client.
This test suite verifies the external Agent Generator service integration,
including service detection, API calls, and error handling.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from backend.api.features.chat.tools.agent_generator import service
class TestServiceConfiguration:
"""Test service configuration detection."""
def setup_method(self):
"""Reset settings singleton before each test."""
service._settings = None
service._client = None
def test_external_service_not_configured_when_host_empty(self):
"""Test that external service is not configured when host is empty."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = ""
with patch.object(service, "_get_settings", return_value=mock_settings):
assert service.is_external_service_configured() is False
def test_external_service_configured_when_host_set(self):
"""Test that external service is configured when host is set."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = "agent-generator.local"
with patch.object(service, "_get_settings", return_value=mock_settings):
assert service.is_external_service_configured() is True
def test_get_base_url(self):
"""Test base URL construction."""
mock_settings = MagicMock()
mock_settings.config.agentgenerator_host = "agent-generator.local"
mock_settings.config.agentgenerator_port = 8000
with patch.object(service, "_get_settings", return_value=mock_settings):
url = service._get_base_url()
assert url == "http://agent-generator.local:8000"
class TestDecomposeGoalExternal:
"""Test decompose_goal_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_decompose_goal_returns_instructions(self):
"""Test successful decomposition returning instructions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1", "Step 2"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
mock_client.post.assert_called_once_with(
"/api/decompose-description", json={"description": "Build a chatbot"}
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_clarifying_questions(self):
"""Test decomposition returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build something")
assert result == {
"type": "clarifying_questions",
"questions": ["What platform?", "What language?"],
}
@pytest.mark.asyncio
async def test_decompose_goal_with_context(self):
"""Test decomposition with additional context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.decompose_goal_external(
"Build a chatbot", context="Use Python"
)
mock_client.post.assert_called_once_with(
"/api/decompose-description",
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
)
@pytest.mark.asyncio
async def test_decompose_goal_returns_unachievable_goal(self):
"""Test decomposition returning unachievable goal response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Do something impossible")
assert result == {
"type": "unachievable_goal",
"reason": "Cannot do X",
"suggested_goal": "Try Y instead",
}
@pytest.mark.asyncio
async def test_decompose_goal_handles_http_error(self):
"""Test decomposition handles HTTP errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.HTTPStatusError(
"Server error", request=MagicMock(), response=MagicMock()
)
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
@pytest.mark.asyncio
async def test_decompose_goal_handles_request_error(self):
"""Test decomposition handles request errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
@pytest.mark.asyncio
async def test_decompose_goal_handles_service_error(self):
"""Test decomposition handles service returning error."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": False,
"error": "Internal error",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.decompose_goal_external("Build a chatbot")
assert result is None
class TestGenerateAgentExternal:
"""Test generate_agent_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_generate_agent_success(self):
"""Test successful agent generation."""
agent_json = {
"name": "Test Agent",
"nodes": [],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": agent_json,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
instructions = {"type": "instructions", "steps": ["Step 1"]}
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_external(instructions)
assert result == agent_json
mock_client.post.assert_called_once_with(
"/api/generate-agent", json={"instructions": instructions}
)
@pytest.mark.asyncio
async def test_generate_agent_handles_error(self):
"""Test agent generation handles errors gracefully."""
mock_client = AsyncMock()
mock_client.post.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_external({"steps": []})
assert result is None
class TestGenerateAgentPatchExternal:
"""Test generate_agent_patch_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_generate_patch_returns_updated_agent(self):
"""Test successful patch generation returning updated agent."""
updated_agent = {
"name": "Updated Agent",
"nodes": [{"id": "1", "block_id": "test"}],
"links": [],
}
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": updated_agent,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_patch_external(
"Add a new node", current_agent
)
assert result == updated_agent
mock_client.post.assert_called_once_with(
"/api/update-agent",
json={
"update_request": "Add a new node",
"current_agent_json": current_agent,
},
)
@pytest.mark.asyncio
async def test_generate_patch_returns_clarifying_questions(self):
"""Test patch generation returning clarifying questions."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.generate_agent_patch_external(
"Add something", {"nodes": []}
)
assert result == {
"type": "clarifying_questions",
"questions": ["What type of node?"],
}
class TestHealthCheck:
"""Test health_check function."""
def setup_method(self):
"""Reset singletons before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_health_check_returns_false_when_not_configured(self):
"""Test health check returns False when service not configured."""
with patch.object(
service, "is_external_service_configured", return_value=False
):
result = await service.health_check()
assert result is False
@pytest.mark.asyncio
async def test_health_check_returns_true_when_healthy(self):
"""Test health check returns True when service is healthy."""
mock_response = MagicMock()
mock_response.json.return_value = {
"status": "healthy",
"blocks_loaded": True,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is True
mock_client.get.assert_called_once_with("/health")
@pytest.mark.asyncio
async def test_health_check_returns_false_when_not_healthy(self):
"""Test health check returns False when service is not healthy."""
mock_response = MagicMock()
mock_response.json.return_value = {
"status": "unhealthy",
"blocks_loaded": False,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is False
@pytest.mark.asyncio
async def test_health_check_returns_false_on_error(self):
"""Test health check returns False on connection error."""
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "is_external_service_configured", return_value=True):
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.health_check()
assert result is False
class TestGetBlocksExternal:
"""Test get_blocks_external function."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_get_blocks_success(self):
"""Test successful blocks retrieval."""
blocks = [
{"id": "block1", "name": "Block 1"},
{"id": "block2", "name": "Block 2"},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"blocks": blocks,
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.get_blocks_external()
assert result == blocks
mock_client.get.assert_called_once_with("/api/blocks")
@pytest.mark.asyncio
async def test_get_blocks_handles_error(self):
"""Test blocks retrieval handles errors gracefully."""
mock_client = AsyncMock()
mock_client.get.side_effect = httpx.RequestError("Connection failed")
with patch.object(service, "_get_client", return_value=mock_client):
result = await service.get_blocks_external()
assert result is None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,8 +1,5 @@
"use client";
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Cpu } from "@phosphor-icons/react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -29,11 +26,6 @@ const sidebarLinkGroups = [
href: "/admin/execution-analytics",
icon: <FileText className="h-6 w-6" />,
},
{
text: "LLM Registry",
href: "/admin/llms",
icon: <Cpu size={24} />,
},
{
text: "Admin User Management",
href: "/admin/settings",

View File

@@ -1,493 +0,0 @@
"use server";
import { revalidatePath } from "next/cache";
// Generated API functions
import {
getV2ListLlmProviders,
postV2CreateLlmProvider,
patchV2UpdateLlmProvider,
deleteV2DeleteLlmProvider,
getV2ListLlmModels,
postV2CreateLlmModel,
patchV2UpdateLlmModel,
patchV2ToggleLlmModelAvailability,
deleteV2DeleteLlmModelAndMigrateWorkflows,
getV2GetModelUsageCount,
getV2ListModelMigrations,
postV2RevertAModelMigration,
getV2ListModelCreators,
postV2CreateModelCreator,
patchV2UpdateModelCreator,
deleteV2DeleteModelCreator,
postV2SetRecommendedModel,
} from "@/app/api/__generated__/endpoints/admin/admin";
// Generated types
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
import { LlmCostUnit } from "@/app/api/__generated__/models/llmCostUnit";
const ADMIN_LLM_PATH = "/admin/llms";
// =============================================================================
// Utilities
// =============================================================================
/**
* Extracts and validates a required string field from FormData.
* Throws an error if the field is missing or empty.
*/
function getRequiredFormField(
formData: FormData,
fieldName: string,
displayName?: string,
): string {
const raw = formData.get(fieldName);
const value = raw ? String(raw).trim() : "";
if (!value) {
throw new Error(`${displayName || fieldName} is required`);
}
return value;
}
/**
* Extracts and validates a required positive number field from FormData.
* Throws an error if the field is missing, empty, or not a positive number.
*/
function getRequiredPositiveNumber(
formData: FormData,
fieldName: string,
displayName?: string,
): number {
const raw = formData.get(fieldName);
const value = Number(raw);
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
throw new Error(`${displayName || fieldName} must be a positive number`);
}
return value;
}
/**
* Extracts and validates a required number field from FormData.
* Throws an error if the field is missing, empty, or not a finite number.
*/
function getRequiredNumber(
formData: FormData,
fieldName: string,
displayName?: string,
): number {
const raw = formData.get(fieldName);
const value = Number(raw);
if (raw === null || raw === "" || !Number.isFinite(value)) {
throw new Error(`${displayName || fieldName} is required`);
}
return value;
}
// =============================================================================
// Provider Actions
// =============================================================================
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
const response = await getV2ListLlmProviders({ include_models: true });
if (response.status !== 200) {
throw new Error("Failed to fetch LLM providers");
}
return response.data;
}
export async function createLlmProviderAction(formData: FormData) {
const payload: UpsertLlmProviderRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
default_credential_provider: formData.get("default_credential_provider")
? String(formData.get("default_credential_provider")).trim()
: undefined,
default_credential_id: formData.get("default_credential_id")
? String(formData.get("default_credential_id")).trim()
: undefined,
default_credential_type: formData.get("default_credential_type")
? String(formData.get("default_credential_type")).trim()
: "api_key",
supports_tools: formData.getAll("supports_tools").includes("on"),
supports_json_output: formData
.getAll("supports_json_output")
.includes("on"),
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
supports_parallel_tool: formData
.getAll("supports_parallel_tool")
.includes("on"),
metadata: {},
};
const response = await postV2CreateLlmProvider(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmProviderAction(
formData: FormData,
): Promise<void> {
const providerId = getRequiredFormField(
formData,
"provider_id",
"Provider id",
);
const response = await deleteV2DeleteLlmProvider(providerId);
if (response.status !== 200) {
const errorData = response.data as { detail?: string };
throw new Error(errorData?.detail || "Failed to delete provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmProviderAction(formData: FormData) {
const providerId = getRequiredFormField(
formData,
"provider_id",
"Provider id",
);
const payload: UpsertLlmProviderRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
default_credential_provider: formData.get("default_credential_provider")
? String(formData.get("default_credential_provider")).trim()
: undefined,
default_credential_id: formData.get("default_credential_id")
? String(formData.get("default_credential_id")).trim()
: undefined,
default_credential_type: formData.get("default_credential_type")
? String(formData.get("default_credential_type")).trim()
: "api_key",
supports_tools: formData.getAll("supports_tools").includes("on"),
supports_json_output: formData
.getAll("supports_json_output")
.includes("on"),
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
supports_parallel_tool: formData
.getAll("supports_parallel_tool")
.includes("on"),
metadata: {},
};
const response = await patchV2UpdateLlmProvider(providerId, payload);
if (response.status !== 200) {
throw new Error("Failed to update LLM provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Model Actions
// =============================================================================
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
const response = await getV2ListLlmModels();
if (response.status !== 200) {
throw new Error("Failed to fetch LLM models");
}
return response.data;
}
export async function createLlmModelAction(formData: FormData) {
const providerId = getRequiredFormField(formData, "provider_id", "Provider");
const creatorId = formData.get("creator_id");
const contextWindow = getRequiredPositiveNumber(
formData,
"context_window",
"Context window",
);
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
// Fetch provider to get default credentials
const providersResponse = await getV2ListLlmProviders({
include_models: false,
});
if (providersResponse.status !== 200) {
throw new Error("Failed to fetch providers");
}
const provider = providersResponse.data.providers.find(
(p) => p.id === providerId,
);
if (!provider) {
throw new Error("Provider not found");
}
const payload: CreateLlmModelRequest = {
slug: String(formData.get("slug") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: providerId,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: contextWindow,
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.getAll("is_enabled").includes("on"),
capabilities: {},
metadata: {},
costs: [
{
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
credit_cost: creditCost,
credential_provider:
provider.default_credential_provider || provider.name,
credential_id: provider.default_credential_id || undefined,
credential_type: provider.default_credential_type || "api_key",
metadata: {},
},
],
};
const response = await postV2CreateLlmModel(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmModelAction(formData: FormData) {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const creatorId = formData.get("creator_id");
const payload: UpdateLlmModelRequest = {
display_name: formData.get("display_name")
? String(formData.get("display_name"))
: undefined,
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: formData.get("provider_id")
? String(formData.get("provider_id"))
: undefined,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: formData.get("context_window")
? Number(formData.get("context_window"))
: undefined,
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.has("is_enabled")
? formData.getAll("is_enabled").includes("on")
: undefined,
costs: formData.get("credit_cost")
? [
{
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
credit_cost: Number(formData.get("credit_cost")),
credential_provider: String(
formData.get("credential_provider") || "",
).trim(),
credential_id: formData.get("credential_id")
? String(formData.get("credential_id"))
: undefined,
credential_type: formData.get("credential_type")
? String(formData.get("credential_type"))
: undefined,
metadata: {},
},
]
: undefined,
};
const response = await patchV2UpdateLlmModel(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to update LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const shouldEnable = formData.get("is_enabled") === "true";
const migrateToSlug = formData.get("migrate_to_slug");
const migrationReason = formData.get("migration_reason");
const customCreditCost = formData.get("custom_credit_cost");
const payload: ToggleLlmModelRequest = {
is_enabled: shouldEnable,
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
migration_reason: migrationReason ? String(migrationReason) : undefined,
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
};
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to toggle LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const rawReplacement = formData.get("replacement_model_slug");
const replacementModelSlug =
rawReplacement && String(rawReplacement).trim()
? String(rawReplacement).trim()
: undefined;
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
replacement_model_slug: replacementModelSlug,
});
if (response.status !== 200) {
throw new Error("Failed to delete model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function fetchLlmModelUsage(
modelId: string,
): Promise<LlmModelUsageResponse> {
const response = await getV2GetModelUsageCount(modelId);
if (response.status !== 200) {
throw new Error("Failed to fetch model usage");
}
return response.data;
}
// =============================================================================
// Migration Actions
// =============================================================================
export async function fetchLlmMigrations(
includeReverted: boolean = false,
): Promise<LlmMigrationsResponse> {
const response = await getV2ListModelMigrations({
include_reverted: includeReverted,
});
if (response.status !== 200) {
throw new Error("Failed to fetch migrations");
}
return response.data;
}
export async function revertLlmMigrationAction(
formData: FormData,
): Promise<void> {
const migrationId = getRequiredFormField(
formData,
"migration_id",
"Migration id",
);
const response = await postV2RevertAModelMigration(migrationId, null);
if (response.status !== 200) {
throw new Error("Failed to revert migration");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Creator Actions
// =============================================================================
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
const response = await getV2ListModelCreators();
if (response.status !== 200) {
throw new Error("Failed to fetch creators");
}
return response.data;
}
export async function createLlmCreatorAction(
formData: FormData,
): Promise<void> {
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await postV2CreateModelCreator(payload);
if (response.status !== 200) {
throw new Error("Failed to create creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await patchV2UpdateModelCreator(creatorId, payload);
if (response.status !== 200) {
throw new Error("Failed to update creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
const response = await deleteV2DeleteModelCreator(creatorId);
if (response.status !== 200) {
throw new Error("Failed to delete creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Recommended Model Actions
// =============================================================================
export async function setRecommendedModelAction(
formData: FormData,
): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const response = await postV2SetRecommendedModel({ model_id: modelId });
if (response.status !== 200) {
throw new Error("Failed to set recommended model");
}
revalidatePath(ADMIN_LLM_PATH);
}

View File

@@ -1,147 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddCreatorModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Creator
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Add a new model creator (the organization that made/trained the
model).
</div>
<form action={handleSubmit} className="space-y-4">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Name (slug) <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
Lowercase identifier (e.g., openai, meta, anthropic)
</p>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={2}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Creator of GPT models..."
/>
</div>
<div className="space-y-2">
<label
htmlFor="website_url"
className="text-sm font-medium text-foreground"
>
Website URL
</label>
<input
id="website_url"
name="website_url"
type="url"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="https://openai.com"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Add Creator"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,314 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { createLlmModelAction } from "../actions";
import { useRouter } from "next/navigation";
interface Props {
providers: LlmProvider[];
creators: LlmModelCreator[];
}
export function AddModelModal({ providers, creators }: Props) {
const [open, setOpen] = useState(false);
const [selectedCreatorId, setSelectedCreatorId] = useState("");
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create model");
} finally {
setIsSubmitting(false);
}
}
// When provider changes, auto-select matching creator if one exists
function handleProviderChange(providerId: string) {
const provider = providers.find((p) => p.id === providerId);
if (provider) {
// Find creator with same name as provider (e.g., "openai" -> "openai")
const matchingCreator = creators.find((c) => c.name === provider.name);
if (matchingCreator) {
setSelectedCreatorId(matchingCreator.id);
} else {
// No matching creator (e.g., OpenRouter hosts other creators' models)
setSelectedCreatorId("");
}
}
}
return (
<Dialog
title="Add Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Model
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Register a new model slug, metadata, and pricing.
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core model details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="slug"
className="text-sm font-medium text-foreground"
>
Model Slug <span className="text-destructive">*</span>
</label>
<input
id="slug"
required
name="slug"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="gpt-4.1-mini-2025-04-14"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="GPT 4.1 Mini"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Model Configuration */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Model Configuration
</h3>
<p className="text-xs text-muted-foreground">
Model capabilities and limits
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="provider_id"
className="text-sm font-medium text-foreground"
>
Provider <span className="text-destructive">*</span>
</label>
<select
id="provider_id"
required
name="provider_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
defaultValue=""
onChange={(e) => handleProviderChange(e.target.value)}
>
<option value="" disabled>
Select provider
</option>
{providers.map((provider) => (
<option key={provider.id} value={provider.id}>
{provider.display_name} ({provider.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who hosts/serves the model
</p>
</div>
<div className="space-y-2">
<label
htmlFor="creator_id"
className="text-sm font-medium text-foreground"
>
Creator
</label>
<select
id="creator_id"
name="creator_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
value={selectedCreatorId}
onChange={(e) => setSelectedCreatorId(e.target.value)}
>
<option value="">No creator selected</option>
{creators.map((creator) => (
<option key={creator.id} value={creator.id}>
{creator.display_name} ({creator.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</p>
</div>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="context_window"
className="text-sm font-medium text-foreground"
>
Context Window <span className="text-destructive">*</span>
</label>
<input
id="context_window"
required
type="number"
name="context_window"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="128000"
min={1}
/>
</div>
<div className="space-y-2">
<label
htmlFor="max_output_tokens"
className="text-sm font-medium text-foreground"
>
Max Output Tokens
</label>
<input
id="max_output_tokens"
type="number"
name="max_output_tokens"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="16384"
min={1}
/>
</div>
</div>
</div>
{/* Pricing */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
<p className="text-xs text-muted-foreground">
Credit cost per run (credentials are managed via the provider)
</p>
</div>
<div className="grid gap-4 sm:grid-cols-1">
<div className="space-y-2">
<label
htmlFor="credit_cost"
className="text-sm font-medium text-foreground"
>
Credit Cost <span className="text-destructive">*</span>
</label>
<input
id="credit_cost"
required
type="number"
name="credit_cost"
step="1"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="5"
min={0}
/>
</div>
</div>
<p className="text-xs text-muted-foreground">
Credit cost is always in platform credits. Credentials are
inherited from the selected provider.
</p>
</div>
{/* Enabled Toggle */}
<div className="flex items-center gap-3 border-t border-border pt-6">
<input type="hidden" name="is_enabled" value="off" />
<input
id="is_enabled"
type="checkbox"
name="is_enabled"
defaultChecked
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor="is_enabled"
className="text-sm font-medium text-foreground"
>
Enabled by default
</label>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,268 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmProviderAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddProviderModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to create provider",
);
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Provider
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Define a new upstream provider and default credential information.
</div>
{/* Setup Instructions */}
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
<div className="space-y-2">
<h4 className="text-sm font-semibold text-foreground">
Before Adding a Provider
</h4>
<p className="text-xs text-muted-foreground">
To use a new provider, you must first configure its credentials in
the backend:
</p>
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
<li>
Add the credential to{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/integrations/credentials_store.py
</code>{" "}
with a UUID, provider name, and settings secret reference
</li>
<li>
Add it to the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/data/block_cost_config.py
</code>
</li>
<li>
Use the <strong>same provider name</strong> in the
&quot;Credential Provider&quot; field below that matches the key
in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>
</li>
</ol>
</div>
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core provider details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Provider Slug <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="e.g. openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Default Credentials */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Default Credentials
</h3>
<p className="text-xs text-muted-foreground">
Credential provider name that matches the key in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>
</p>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_provider"
className="text-sm font-medium text-foreground"
>
Credential Provider <span className="text-destructive">*</span>
</label>
<input
id="default_credential_provider"
name="default_credential_provider"
required
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
<strong>Important:</strong> This must exactly match the key in
the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
block_cost_config.py
</code>
. Common values: &quot;openai&quot;, &quot;anthropic&quot;,
&quot;groq&quot;, &quot;open_router&quot;, etc.
</p>
</div>
</div>
{/* Capabilities */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Capabilities
</h3>
<p className="text-xs text-muted-foreground">
Provider feature flags
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
{[
{ name: "supports_tools", label: "Supports tools" },
{ name: "supports_json_output", label: "Supports JSON output" },
{ name: "supports_reasoning", label: "Supports reasoning" },
{
name: "supports_parallel_tool",
label: "Supports parallel tool calls",
},
].map(({ name, label }) => (
<div
key={name}
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
>
<input type="hidden" name={name} value="off" />
<input
id={name}
type="checkbox"
name={name}
defaultChecked={
name !== "supports_reasoning" &&
name !== "supports_parallel_tool"
}
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor={name}
className="text-sm font-medium text-foreground"
>
{label}
</label>
</div>
))}
</div>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Provider"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,195 +0,0 @@
"use client";
import { useState } from "react";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { updateLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
import { DeleteCreatorModal } from "./DeleteCreatorModal";
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
if (!creators.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No creators registered yet.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Creator</TableHead>
<TableHead>Description</TableHead>
<TableHead>Website</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{creators.map((creator) => (
<TableRow key={creator.id}>
<TableCell>
<div className="font-medium">{creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{creator.name}
</div>
</TableCell>
<TableCell>
<span className="text-sm text-muted-foreground">
{creator.description || "—"}
</span>
</TableCell>
<TableCell>
{creator.website_url ? (
<a
href={creator.website_url}
target="_blank"
rel="noopener noreferrer"
className="text-sm text-primary hover:underline"
>
{(() => {
try {
return new URL(creator.website_url).hostname;
} catch {
return creator.website_url;
}
})()}
</a>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
<EditCreatorModal creator={creator} />
<DeleteCreatorModal creator={creator} />
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
}
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label className="text-sm font-medium">Name (slug)</label>
<input
required
name="name"
defaultValue={creator.name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Display Name</label>
<input
required
name="display_name"
defaultValue={creator.display_name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Description</label>
<textarea
name="description"
rows={2}
defaultValue={creator.description ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Website URL</label>
<input
name="website_url"
type="url"
defaultValue={creator.website_url ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,107 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { deleteLlmCreatorAction } from "../actions";
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete creator");
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "480px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{creator.display_name}</span>{" "}
<span className="text-muted-foreground">
({creator.name})
</span>
</p>
<p className="mt-2 text-muted-foreground">
Models using this creator will have their creator field
cleared. This is safe and won&apos;t affect model
functionality.
</p>
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isDeleting}
type="button"
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={isDeleting}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting ? "Deleting..." : "Delete Creator"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,224 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DeleteModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [usageLoading, setUsageLoading] = useState(false);
const [usageError, setUsageError] = useState<string | null>(null);
// Filter out the current model and disabled models from replacement options
const replacementOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
// Check if migration is required (has blocks using this model)
const requiresMigration = usageCount !== null && usageCount > 0;
async function fetchUsage() {
setUsageLoading(true);
setUsageError(null);
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch (err) {
console.error("Failed to fetch model usage:", err);
setUsageError("Failed to load usage count");
setUsageCount(null);
} finally {
setUsageLoading(false);
}
}
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete model");
} finally {
setIsDeleting(false);
}
}
// Determine if delete button should be enabled
const canDelete =
!isDeleting &&
!usageLoading &&
usageCount !== null &&
(requiresMigration
? selectedReplacement && replacementOptions.length > 0
: true);
return (
<Dialog
title="Delete Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
setUsageError(null);
setError(null);
setSelectedReplacement("");
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
{requiresMigration
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
: "This action cannot be undone."}
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageLoading && (
<p className="mt-2 text-muted-foreground">
Loading usage count...
</p>
)}
{usageError && (
<p className="mt-2 text-destructive">{usageError}</p>
)}
{!usageLoading && !usageError && usageCount !== null && (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use this model
</p>
)}
{requiresMigration && (
<p className="mt-2 text-muted-foreground">
All workflows currently using this model will be
automatically updated to use the replacement model you
choose below.
</p>
)}
{!usageLoading && usageCount === 0 && (
<p className="mt-2 text-muted-foreground">
No workflows are using this model. It can be safely deleted.
</p>
)}
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input
type="hidden"
name="replacement_model_slug"
value={selectedReplacement}
/>
{requiresMigration && (
<label className="text-sm font-medium">
<span className="mb-2 block">
Select Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedReplacement}
onChange={(e) => setSelectedReplacement(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{replacementOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{replacementOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No replacement models available. You must have at least one
other enabled model before deleting this one.
</p>
)}
</label>
)}
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setSelectedReplacement("");
setError(null);
}}
disabled={isDeleting}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={!canDelete}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting
? "Deleting..."
: requiresMigration
? "Delete and Migrate"
: "Delete"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,129 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { deleteLlmProviderAction } from "../actions";
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
const [open, setOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
const modelCount = provider.models?.length ?? 0;
const hasModels = modelCount > 0;
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to delete provider",
);
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "480px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-4">
<div
className={`rounded-lg border p-4 ${
hasModels
? "border-destructive/30 bg-destructive/10"
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
}`}
>
<div className="flex items-start gap-3">
<div
className={`flex-shrink-0 ${
hasModels
? "text-destructive"
: "text-amber-600 dark:text-amber-400"
}`}
>
{hasModels ? "🚫" : "⚠️"}
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{provider.display_name}</span>{" "}
<span className="text-muted-foreground">
({provider.name})
</span>
</p>
{hasModels ? (
<p className="mt-2 text-destructive">
This provider has {modelCount} model(s). You must delete all
models before you can delete this provider.
</p>
) : (
<p className="mt-2 text-muted-foreground">
This provider has no models and can be safely deleted.
</p>
)}
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="provider_id" value={provider.id} />
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isDeleting}
type="button"
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={isDeleting || hasModels}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
>
{isDeleting ? "Deleting..." : "Delete Provider"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,288 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DisableModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const [open, setOpen] = useState(false);
const [isDisabling, setIsDisabling] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [selectedMigration, setSelectedMigration] = useState<string>("");
const [wantsMigration, setWantsMigration] = useState(false);
const [migrationReason, setMigrationReason] = useState("");
const [customCreditCost, setCustomCreditCost] = useState<string>("");
// Filter out the current model and disabled models from replacement options
const migrationOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
async function fetchUsage() {
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch {
setUsageCount(null);
}
}
async function handleDisable(formData: FormData) {
setIsDisabling(true);
setError(null);
try {
await toggleLlmModelAction(formData);
setOpen(false);
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to disable model");
} finally {
setIsDisabling(false);
}
}
function resetState() {
setError(null);
setSelectedMigration("");
setWantsMigration(false);
setMigrationReason("");
setCustomCreditCost("");
}
const hasUsage = usageCount !== null && usageCount > 0;
return (
<Dialog
title="Disable Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
resetState();
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0"
>
Disable
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Disabling a model will hide it from users when creating new workflows.
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to disable:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageCount === null ? (
<p className="mt-2 text-muted-foreground">
Loading usage data...
</p>
) : usageCount > 0 ? (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use this model
</p>
) : (
<p className="mt-2 text-muted-foreground">
No workflows are currently using this model.
</p>
)}
</div>
</div>
</div>
{hasUsage && (
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
<label className="flex items-start gap-3">
<input
type="checkbox"
checked={wantsMigration}
onChange={(e) => {
setWantsMigration(e.target.checked);
if (!e.target.checked) {
setSelectedMigration("");
}
}}
className="mt-1"
/>
<div className="text-sm">
<span className="font-medium">
Migrate existing workflows to another model
</span>
<p className="mt-1 text-muted-foreground">
Creates a revertible migration record. If unchecked,
existing workflows will use automatic fallback to an enabled
model from the same provider.
</p>
</div>
</label>
{wantsMigration && (
<div className="space-y-4 border-t border-border pt-4">
<label className="block text-sm font-medium">
<span className="mb-2 block">
Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedMigration}
onChange={(e) => setSelectedMigration(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{migrationOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{migrationOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No other enabled models available for migration.
</p>
)}
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Migration Reason{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="text"
value={migrationReason}
onChange={(e) => setMigrationReason(e.target.value)}
placeholder="e.g., Provider outage, Cost reduction"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Helps track why the migration was made
</p>
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Custom Credit Cost{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="number"
min="0"
value={customCreditCost}
onChange={(e) => setCustomCreditCost(e.target.value)}
placeholder="Leave blank to use target model's cost"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Override pricing for migrated workflows. When set, billing
will use this cost instead of the target model&apos;s
cost.
</p>
</label>
</div>
)}
</div>
)}
<form action={handleDisable} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input type="hidden" name="is_enabled" value="false" />
{wantsMigration && selectedMigration && (
<>
<input
type="hidden"
name="migrate_to_slug"
value={selectedMigration}
/>
{migrationReason && (
<input
type="hidden"
name="migration_reason"
value={migrationReason}
/>
)}
{customCreditCost && (
<input
type="hidden"
name="custom_credit_cost"
value={customCreditCost}
/>
)}
</>
)}
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
resetState();
}}
disabled={isDisabling}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={
isDisabling ||
(wantsMigration && !selectedMigration) ||
usageCount === null
}
>
{isDisabling
? "Disabling..."
: wantsMigration && selectedMigration
? "Disable & Migrate"
: "Disable Model"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,223 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { updateLlmModelAction } from "../actions";
export function EditModelModal({
model,
providers,
creators,
}: {
model: LlmModel;
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const cost = model.costs?.[0];
const provider = providers.find((p) => p.id === model.provider_id);
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update model");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Update model metadata and pricing information.
</div>
{error && (
<div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Display Name
<input
required
name="display_name"
defaultValue={model.display_name}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
/>
</label>
<label className="text-sm font-medium">
Provider
<select
required
name="provider_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.provider_id}
>
{providers.map((p) => (
<option key={p.id} value={p.id}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who hosts/serves the model
</span>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Creator
<select
name="creator_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.creator_id ?? ""}
>
<option value="">No creator selected</option>
{creators.map((c) => (
<option key={c.id} value={c.id}>
{c.display_name} ({c.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</span>
</label>
</div>
<label className="text-sm font-medium">
Description
<textarea
name="description"
rows={2}
defaultValue={model.description ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
placeholder="Optional description..."
/>
</label>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Context Window
<input
required
type="number"
name="context_window"
defaultValue={model.context_window}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
<label className="text-sm font-medium">
Max Output Tokens
<input
type="number"
name="max_output_tokens"
defaultValue={model.max_output_tokens ?? undefined}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Credit Cost
<input
required
type="number"
name="credit_cost"
defaultValue={cost?.credit_cost ?? 0}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={0}
/>
<span className="text-xs text-muted-foreground">
Credits charged per run
</span>
</label>
<label className="text-sm font-medium">
Credential Provider
<select
required
name="credential_provider"
defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="" disabled>
Select provider
</option>
{providers.map((p) => (
<option key={p.id} value={p.name}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Must match a key in PROVIDER_CREDENTIALS
</span>
</label>
</div>
{/* Hidden defaults for credential_type and unit */}
<input
type="hidden"
name="credential_type"
value={
cost?.credential_type ??
provider?.default_credential_type ??
"api_key"
}
/>
<input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />
<Dialog.Footer>
<Button
type="button"
variant="ghost"
size="small"
onClick={() => setOpen(false)}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,263 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { updateLlmProviderAction } from "../actions";
import { useRouter } from "next/navigation";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to update provider",
);
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Update provider configuration and capabilities.
</div>
<form action={handleSubmit} className="space-y-6">
<input type="hidden" name="provider_id" value={provider.id} />
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core provider details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Provider Slug <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
defaultValue={provider.name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="e.g. openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
defaultValue={provider.display_name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
defaultValue={provider.description ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Default Credentials */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Default Credentials
</h3>
<p className="text-xs text-muted-foreground">
Credential provider name that matches the key in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="default_credential_provider"
className="text-sm font-medium text-foreground"
>
Credential Provider
</label>
<input
id="default_credential_provider"
name="default_credential_provider"
defaultValue={provider.default_credential_provider ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_id"
className="text-sm font-medium text-foreground"
>
Credential ID
</label>
<input
id="default_credential_id"
name="default_credential_id"
defaultValue={provider.default_credential_id ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional credential ID"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_type"
className="text-sm font-medium text-foreground"
>
Credential Type
</label>
<input
id="default_credential_type"
name="default_credential_type"
defaultValue={provider.default_credential_type ?? "api_key"}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="api_key"
/>
</div>
</div>
{/* Capabilities */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Capabilities
</h3>
<p className="text-xs text-muted-foreground">
Provider feature flags
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
{[
{
name: "supports_tools",
label: "Supports tools",
checked: provider.supports_tools,
},
{
name: "supports_json_output",
label: "Supports JSON output",
checked: provider.supports_json_output,
},
{
name: "supports_reasoning",
label: "Supports reasoning",
checked: provider.supports_reasoning,
},
{
name: "supports_parallel_tool",
label: "Supports parallel tool calls",
checked: provider.supports_parallel_tool,
},
].map(({ name, label, checked }) => (
<div
key={name}
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
>
<input type="hidden" name={name} value="off" />
<input
id={name}
type="checkbox"
name={name}
defaultChecked={checked}
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor={name}
className="text-sm font-medium text-foreground"
>
{label}
</label>
</div>
))}
</div>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Saving..." : "Save Changes"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,131 +0,0 @@
"use client";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { AddProviderModal } from "./AddProviderModal";
import { AddModelModal } from "./AddModelModal";
import { AddCreatorModal } from "./AddCreatorModal";
import { ProviderList } from "./ProviderList";
import { ModelsTable } from "./ModelsTable";
import { MigrationsTable } from "./MigrationsTable";
import { CreatorsTable } from "./CreatorsTable";
import { RecommendedModelSelector } from "./RecommendedModelSelector";
interface Props {
providers: LlmProvider[];
models: LlmModel[];
migrations: LlmModelMigration[];
creators: LlmModelCreator[];
}
function AdminErrorFallback() {
return (
<div className="mx-auto max-w-xl p-6">
<ErrorCard
responseError={{
message:
"An error occurred while loading the LLM Registry. Please refresh the page.",
}}
context="llm-registry"
onRetry={() => window.location.reload()}
/>
</div>
);
}
export function LlmRegistryDashboard({
providers,
models,
migrations,
creators,
}: Props) {
return (
<ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
<div className="mx-auto p-6">
<div className="flex flex-col gap-6">
{/* Header */}
<div>
<h1 className="text-3xl font-bold">LLM Registry</h1>
<p className="text-muted-foreground">
Manage providers, creators, models, and credit pricing
</p>
</div>
{/* Active Migrations Section - Only show if there are migrations */}
{migrations.length > 0 && (
<div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
<div className="mb-4">
<h2 className="text-xl font-semibold">Active Migrations</h2>
<p className="mt-1 text-sm text-muted-foreground">
These migrations can be reverted to restore workflows to their
original model
</p>
</div>
<MigrationsTable migrations={migrations} />
</div>
)}
{/* Providers & Creators Section - Side by Side */}
<div className="grid gap-6 lg:grid-cols-2">
{/* Providers */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Providers</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who hosts/serves the models
</p>
</div>
<AddProviderModal />
</div>
<ProviderList providers={providers} />
</div>
{/* Creators */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Creators</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who made/trained the models
</p>
</div>
<AddCreatorModal />
</div>
<CreatorsTable creators={creators} />
</div>
</div>
{/* Models Section */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Models</h2>
<p className="mt-1 text-sm text-muted-foreground">
Toggle availability, adjust context windows, and update credit
pricing
</p>
</div>
<AddModelModal providers={providers} creators={creators} />
</div>
{/* Recommended Model Selector */}
<div className="mb-6">
<RecommendedModelSelector models={models} />
</div>
<ModelsTable
models={models}
providers={providers}
creators={creators}
/>
</div>
</div>
</div>
</ErrorBoundary>
);
}

View File

@@ -1,133 +0,0 @@
"use client";
import { useState } from "react";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { revertLlmMigrationAction } from "../actions";
export function MigrationsTable({
migrations,
}: {
migrations: LlmModelMigration[];
}) {
if (!migrations.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No active migrations. Migrations are created when you disable a model
with the &quot;Migrate existing workflows&quot; option.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Migration</TableHead>
<TableHead>Reason</TableHead>
<TableHead>Nodes Affected</TableHead>
<TableHead>Custom Cost</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{migrations.map((migration) => (
<MigrationRow key={migration.id} migration={migration} />
))}
</TableBody>
</Table>
</div>
);
}
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
const [isReverting, setIsReverting] = useState(false);
const [error, setError] = useState<string | null>(null);
async function handleRevert(formData: FormData) {
setIsReverting(true);
setError(null);
try {
await revertLlmMigrationAction(formData);
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to revert migration",
);
} finally {
setIsReverting(false);
}
}
const createdDate = new Date(migration.created_at);
return (
<>
<TableRow>
<TableCell>
<div className="text-sm">
<span className="font-medium">{migration.source_model_slug}</span>
<span className="mx-2 text-muted-foreground"></span>
<span className="font-medium">{migration.target_model_slug}</span>
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{migration.reason || "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm">{migration.node_count}</div>
</TableCell>
<TableCell>
<div className="text-sm">
{migration.custom_credit_cost !== null &&
migration.custom_credit_cost !== undefined
? `${migration.custom_credit_cost} credits`
: "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{createdDate.toLocaleDateString()}{" "}
{createdDate.toLocaleTimeString([], {
hour: "2-digit",
minute: "2-digit",
})}
</div>
</TableCell>
<TableCell className="text-right">
<form action={handleRevert} className="inline">
<input type="hidden" name="migration_id" value={migration.id} />
<Button
type="submit"
variant="outline"
size="small"
disabled={isReverting}
>
{isReverting ? "Reverting..." : "Revert"}
</Button>
</form>
</TableCell>
</TableRow>
{error && (
<TableRow>
<TableCell colSpan={6}>
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
{error}
</div>
</TableCell>
</TableRow>
)}
</>
);
}

View File

@@ -1,262 +0,0 @@
"use client";
import { useState, useEffect, useRef } from "react";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { toggleLlmModelAction } from "../actions";
import { DeleteModelModal } from "./DeleteModelModal";
import { DisableModelModal } from "./DisableModelModal";
import { EditModelModal } from "./EditModelModal";
import { Star, Spinner } from "@phosphor-icons/react";
import { getV2ListLlmModels } from "@/app/api/__generated__/endpoints/admin/admin";
const PAGE_SIZE = 50;
export function ModelsTable({
models: initialModels,
providers,
creators,
}: {
models: LlmModel[];
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
const [models, setModels] = useState<LlmModel[]>(initialModels);
const [currentPage, setCurrentPage] = useState(1);
const [hasMore, setHasMore] = useState(initialModels.length === PAGE_SIZE);
const [isLoading, setIsLoading] = useState(false);
const loadedPagesRef = useRef(1);
// Sync with parent when initialModels changes (e.g., after enable/disable)
// Re-fetch all loaded pages to preserve expanded state
useEffect(() => {
async function refetchAllPages() {
const pagesToLoad = loadedPagesRef.current;
if (pagesToLoad === 1) {
// Only first page loaded, just use initialModels
setModels(initialModels);
setHasMore(initialModels.length === PAGE_SIZE);
return;
}
// Re-fetch all pages we had loaded
const allModels: LlmModel[] = [...initialModels];
let lastPageHadFullResults = initialModels.length === PAGE_SIZE;
for (let page = 2; page <= pagesToLoad; page++) {
try {
const response = await getV2ListLlmModels({
page,
page_size: PAGE_SIZE,
});
if (response.status === 200) {
allModels.push(...response.data.models);
lastPageHadFullResults = response.data.models.length === PAGE_SIZE;
}
} catch (err) {
console.error(`Error refetching page ${page}:`, err);
break;
}
}
setModels(allModels);
setHasMore(lastPageHadFullResults);
}
refetchAllPages();
}, [initialModels]);
async function loadMore() {
if (isLoading) return;
setIsLoading(true);
try {
const nextPage = currentPage + 1;
const response = await getV2ListLlmModels({
page: nextPage,
page_size: PAGE_SIZE,
});
if (response.status === 200) {
setModels((prev) => [...prev, ...response.data.models]);
setCurrentPage(nextPage);
loadedPagesRef.current = nextPage;
setHasMore(response.data.models.length === PAGE_SIZE);
}
} catch (err) {
console.error("Error loading more models:", err);
} finally {
setIsLoading(false);
}
}
if (!models.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No models registered yet.
</div>
);
}
const providerLookup = new Map(
providers.map((provider) => [provider.id, provider]),
);
return (
<div>
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead>Provider</TableHead>
<TableHead>Creator</TableHead>
<TableHead>Context Window</TableHead>
<TableHead>Max Output</TableHead>
<TableHead>Cost</TableHead>
<TableHead>Status</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{models.map((model) => {
const cost = model.costs?.[0];
const provider = providerLookup.get(model.provider_id);
return (
<TableRow
key={model.id}
className={model.is_enabled ? "" : "opacity-60"}
>
<TableCell>
<div className="font-medium">{model.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.slug}
</div>
</TableCell>
<TableCell>
{provider ? (
<>
<div>{provider.display_name}</div>
<div className="text-xs text-muted-foreground">
{provider.name}
</div>
</>
) : (
model.provider_id
)}
</TableCell>
<TableCell>
{model.creator ? (
<>
<div>{model.creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.creator.name}
</div>
</>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>{model.context_window.toLocaleString()}</TableCell>
<TableCell>
{model.max_output_tokens
? model.max_output_tokens.toLocaleString()
: "—"}
</TableCell>
<TableCell>
{cost ? (
<>
<div className="font-medium">
{cost.credit_cost} credits
</div>
<div className="text-xs text-muted-foreground">
{cost.credential_provider}
</div>
</>
) : (
"—"
)}
</TableCell>
<TableCell>
<div className="flex flex-col gap-1">
<span
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
model.is_enabled
? "bg-primary/10 text-primary"
: "bg-muted text-muted-foreground"
}`}
>
{model.is_enabled ? "Enabled" : "Disabled"}
</span>
{model.is_recommended && (
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
<Star size={12} weight="fill" />
Recommended
</span>
)}
</div>
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
{model.is_enabled ? (
<DisableModelModal
model={model}
availableModels={models}
/>
) : (
<EnableModelButton modelId={model.id} />
)}
<EditModelModal
model={model}
providers={providers}
creators={creators}
/>
<DeleteModelModal model={model} availableModels={models} />
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
{hasMore && (
<div className="mt-4 flex justify-center">
<Button onClick={loadMore} disabled={isLoading} variant="outline">
{isLoading ? (
<>
<Spinner className="mr-2 h-4 w-4 animate-spin" />
Loading...
</>
) : (
"Load More"
)}
</Button>
</div>
)}
</div>
);
}
function EnableModelButton({ modelId }: { modelId: string }) {
return (
<form action={toggleLlmModelAction} className="inline">
<input type="hidden" name="model_id" value={modelId} />
<input type="hidden" name="is_enabled" value="true" />
<Button type="submit" variant="outline" size="small" className="min-w-0">
Enable
</Button>
</form>
);
}

View File

@@ -1,94 +0,0 @@
"use client";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { DeleteProviderModal } from "./DeleteProviderModal";
import { EditProviderModal } from "./EditProviderModal";
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
if (!providers.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No providers configured yet.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>Display Name</TableHead>
<TableHead>Default Credential</TableHead>
<TableHead>Capabilities</TableHead>
<TableHead>Models</TableHead>
<TableHead className="w-[100px]">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{providers.map((provider) => (
<TableRow key={provider.id}>
<TableCell className="font-medium">{provider.name}</TableCell>
<TableCell>{provider.display_name}</TableCell>
<TableCell>
{provider.default_credential_provider
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
: "—"}
</TableCell>
<TableCell className="text-sm text-muted-foreground">
<div className="flex flex-wrap gap-2">
{provider.supports_tools && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Tools
</span>
)}
{provider.supports_json_output && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
JSON
</span>
)}
{provider.supports_reasoning && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Reasoning
</span>
)}
{provider.supports_parallel_tool && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Parallel Tools
</span>
)}
</div>
</TableCell>
<TableCell className="text-sm">
<span
className={
(provider.models?.length ?? 0) > 0
? "text-foreground"
: "text-muted-foreground"
}
>
{provider.models?.length ?? 0}
</span>
</TableCell>
<TableCell>
<div className="flex gap-2">
<EditProviderModal provider={provider} />
<DeleteProviderModal provider={provider} />
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
}

View File

@@ -1,87 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { Button } from "@/components/atoms/Button/Button";
import { setRecommendedModelAction } from "../actions";
import { Star } from "@phosphor-icons/react";
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
const router = useRouter();
const enabledModels = models.filter((m) => m.is_enabled);
const currentRecommended = models.find((m) => m.is_recommended);
const [selectedModelId, setSelectedModelId] = useState<string>(
currentRecommended?.id || "",
);
const [isSaving, setIsSaving] = useState(false);
const [error, setError] = useState<string | null>(null);
const hasChanges = selectedModelId !== (currentRecommended?.id || "");
async function handleSave() {
if (!selectedModelId) return;
setIsSaving(true);
setError(null);
try {
const formData = new FormData();
formData.set("model_id", selectedModelId);
await setRecommendedModelAction(formData);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to save");
} finally {
setIsSaving(false);
}
}
return (
<div className="rounded-lg border border-border bg-card p-4">
<div className="mb-3 flex items-center gap-2">
<Star size={20} weight="fill" className="text-amber-500" />
<h3 className="text-sm font-semibold">Recommended Model</h3>
</div>
<p className="mb-3 text-xs text-muted-foreground">
The recommended model is shown as the default suggestion in model
selection dropdowns throughout the platform.
</p>
<div className="flex items-center gap-3">
<select
value={selectedModelId}
onChange={(e) => setSelectedModelId(e.target.value)}
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
disabled={isSaving}
>
<option value="">-- Select a model --</option>
{enabledModels.map((model) => (
<option key={model.id} value={model.id}>
{model.display_name} ({model.slug})
</option>
))}
</select>
<Button
type="button"
variant="primary"
size="small"
onClick={handleSave}
disabled={!hasChanges || !selectedModelId || isSaving}
>
{isSaving ? "Saving..." : "Save"}
</Button>
</div>
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
{currentRecommended && !hasChanges && (
<p className="mt-2 text-xs text-muted-foreground">
Currently set to:{" "}
<span className="font-medium">{currentRecommended.display_name}</span>
</p>
)}
</div>
);
}

View File

@@ -1,46 +0,0 @@
/**
* Server-side data fetching for LLM Registry page.
*/
import {
fetchLlmCreators,
fetchLlmMigrations,
fetchLlmModels,
fetchLlmProviders,
} from "./actions";
export async function getLlmRegistryPageData() {
// Fetch providers and models (required)
const [providersResponse, modelsResponse] = await Promise.all([
fetchLlmProviders(),
fetchLlmModels(),
]);
// Fetch migrations separately with fallback (table might not exist yet)
let migrations: Awaited<ReturnType<typeof fetchLlmMigrations>>["migrations"] =
[];
try {
const migrationsResponse = await fetchLlmMigrations(false);
migrations = migrationsResponse.migrations;
} catch {
// Migrations table might not exist yet - that's ok, just show empty list
console.warn("Could not fetch migrations - table may not exist yet");
}
// Fetch creators separately with fallback (table might not exist yet)
let creators: Awaited<ReturnType<typeof fetchLlmCreators>>["creators"] = [];
try {
const creatorsResponse = await fetchLlmCreators();
creators = creatorsResponse.creators;
} catch {
// Creators table might not exist yet - that's ok, just show empty list
console.warn("Could not fetch creators - table may not exist yet");
}
return {
providers: providersResponse.providers,
models: modelsResponse.models,
migrations,
creators,
};
}

View File

@@ -1,14 +0,0 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
async function LlmRegistryPage() {
const data = await getLlmRegistryPageData();
return <LlmRegistryDashboard {...data} />;
}
export default async function AdminLlmRegistryPage() {
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedLlmRegistryPage = await withAdminAccess(LlmRegistryPage);
return <ProtectedLlmRegistryPage />;
}

View File

@@ -38,8 +38,12 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
return outputNodes
.map((node) => {
const executionResult = node.data.nodeExecutionResult;
const outputData = executionResult?.output_data?.output;
const executionResults = node.data.nodeExecutionResults || [];
const latestResult =
executionResults.length > 0
? executionResults[executionResults.length - 1]
: undefined;
const outputData = latestResult?.output_data?.output;
const renderer = globalRegistry.getRenderer(outputData);

View File

@@ -153,6 +153,9 @@ export const useRunInputDialog = ({
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
);
useNodeStore.getState().clearAllNodeExecutionResults();
useNodeStore.getState().cleanNodesStatuses();
await executeGraph({
graphId: flowID ?? "",
graphVersion: flowVersion || null,

View File

@@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({
const {
currentHITLSafeMode,
showHITLToggle,
isHITLStateUndetermined,
handleHITLToggle,
currentSensitiveActionSafeMode,
showSensitiveActionToggle,
@@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({
return null;
}
const showHITL = showHITLToggle && !isHITLStateUndetermined;
const showSensitive = showSensitiveActionToggle;
if (!showHITL && !showSensitive) {
return null;
}
return (
<div className={cn("fixed z-50 flex flex-col gap-2", className)}>
{showHITL && (
{showHITLToggle && (
<SafeModeButton
isEnabled={currentHITLSafeMode}
label="Human in the loop block approval"
@@ -119,7 +111,7 @@ export function FloatingSafeModeToggle({
fullWidth={fullWidth}
/>
)}
{showSensitive && (
{showSensitiveActionToggle && (
<SafeModeButton
isEnabled={currentSensitiveActionSafeMode}
label="Sensitive actions blocks approval"

View File

@@ -34,7 +34,7 @@ export type CustomNodeData = {
uiType: BlockUIType;
block_id: string;
status?: AgentExecutionStatus;
nodeExecutionResult?: NodeExecutionResult;
nodeExecutionResults?: NodeExecutionResult[];
staticOutput?: boolean;
// TODO : We need better type safety for the following backend fields.
costs: BlockCost[];
@@ -75,7 +75,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
(value) => value !== null && value !== undefined && value !== "",
);
const outputData = data.nodeExecutionResult?.output_data;
const latestResult =
data.nodeExecutionResults && data.nodeExecutionResults.length > 0
? data.nodeExecutionResults[data.nodeExecutionResults.length - 1]
: undefined;
const outputData = latestResult?.output_data;
const hasOutputError =
typeof outputData === "object" &&
outputData !== null &&

View File

@@ -14,10 +14,15 @@ import { useNodeOutput } from "./useNodeOutput";
import { ViewMoreData } from "./components/ViewMoreData";
export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
const { outputData, copiedKey, handleCopy, executionResultId, inputData } =
useNodeOutput(nodeId);
const {
latestOutputData,
copiedKey,
handleCopy,
executionResultId,
latestInputData,
} = useNodeOutput(nodeId);
if (Object.keys(outputData).length === 0) {
if (Object.keys(latestOutputData).length === 0) {
return null;
}
@@ -41,18 +46,19 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
<div className="space-y-2">
<Text variant="small-medium">Input</Text>
<ContentRenderer value={inputData} shortContent={false} />
<ContentRenderer value={latestInputData} shortContent={false} />
<div className="mt-1 flex justify-end gap-1">
<NodeDataViewer
data={inputData}
pinName="Input"
nodeId={nodeId}
execId={executionResultId}
dataType="input"
/>
<Button
variant="secondary"
size="small"
onClick={() => handleCopy("input", inputData)}
onClick={() => handleCopy("input", latestInputData)}
className={cn(
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
copiedKey === "input" &&
@@ -68,70 +74,72 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
</div>
</div>
{Object.entries(outputData)
{Object.entries(latestOutputData)
.slice(0, 2)
.map(([key, value]) => (
<div key={key} className="flex flex-col gap-2">
<div className="flex items-center gap-2">
<Text
variant="small-medium"
className="!font-semibold text-slate-600"
>
Pin:
</Text>
<Text variant="small" className="text-slate-700">
{beautifyString(key)}
</Text>
</div>
<div className="w-full space-y-2">
<Text
variant="small"
className="!font-semibold text-slate-600"
>
Data:
</Text>
<div className="relative space-y-2">
{value.map((item, index) => (
<div key={index}>
<ContentRenderer value={item} shortContent={true} />
</div>
))}
.map(([key, value]) => {
return (
<div key={key} className="flex flex-col gap-2">
<div className="flex items-center gap-2">
<Text
variant="small-medium"
className="!font-semibold text-slate-600"
>
Pin:
</Text>
<Text variant="small" className="text-slate-700">
{beautifyString(key)}
</Text>
</div>
<div className="w-full space-y-2">
<Text
variant="small"
className="!font-semibold text-slate-600"
>
Data:
</Text>
<div className="relative space-y-2">
{value.map((item, index) => (
<div key={index}>
<ContentRenderer
value={item}
shortContent={true}
/>
</div>
))}
<div className="mt-1 flex justify-end gap-1">
<NodeDataViewer
data={value}
pinName={key}
execId={executionResultId}
/>
<Button
variant="secondary"
size="small"
onClick={() => handleCopy(key, value)}
className={cn(
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
copiedKey === key &&
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
)}
>
{copiedKey === key ? (
<CheckIcon size={12} className="text-green-600" />
) : (
<CopyIcon size={12} />
)}
</Button>
<div className="mt-1 flex justify-end gap-1">
<NodeDataViewer
pinName={key}
nodeId={nodeId}
execId={executionResultId}
/>
<Button
variant="secondary"
size="small"
onClick={() => handleCopy(key, value)}
className={cn(
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
copiedKey === key &&
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
)}
>
{copiedKey === key ? (
<CheckIcon
size={12}
className="text-green-600"
/>
) : (
<CopyIcon size={12} />
)}
</Button>
</div>
</div>
</div>
</div>
</div>
))}
);
})}
</div>
{Object.keys(outputData).length > 2 && (
<ViewMoreData
outputData={outputData}
execId={executionResultId}
/>
)}
<ViewMoreData nodeId={nodeId} />
</AccordionContent>
</AccordionItem>
</Accordion>

View File

@@ -19,22 +19,51 @@ import {
CopyIcon,
DownloadIcon,
} from "@phosphor-icons/react";
import { FC } from "react";
import React, { FC } from "react";
import { useNodeDataViewer } from "./useNodeDataViewer";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useShallow } from "zustand/react/shallow";
import { NodeDataType } from "../../helpers";
interface NodeDataViewerProps {
data: any;
export interface NodeDataViewerProps {
data?: any;
pinName: string;
nodeId?: string;
execId?: string;
isViewMoreData?: boolean;
dataType?: NodeDataType;
}
export const NodeDataViewer: FC<NodeDataViewerProps> = ({
data,
pinName,
nodeId,
execId = "N/A",
isViewMoreData = false,
dataType = "output",
}) => {
const executionResults = useNodeStore(
useShallow((state) =>
nodeId ? state.getNodeExecutionResults(nodeId) : [],
),
);
const latestInputData = useNodeStore(
useShallow((state) =>
nodeId ? state.getLatestNodeInputData(nodeId) : undefined,
),
);
const accumulatedOutputData = useNodeStore(
useShallow((state) =>
nodeId ? state.getAccumulatedNodeOutputData(nodeId) : {},
),
);
const resolvedData =
data ??
(dataType === "input"
? (latestInputData ?? {})
: (accumulatedOutputData[pinName] ?? []));
const {
outputItems,
copyExecutionId,
@@ -42,7 +71,20 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
handleDownloadItem,
dataArray,
copiedIndex,
} = useNodeDataViewer(data, pinName, execId);
groupedExecutions,
totalGroupedItems,
handleCopyGroupedItem,
handleDownloadGroupedItem,
copiedKey,
} = useNodeDataViewer(
resolvedData,
pinName,
execId,
executionResults,
dataType,
);
const shouldGroupExecutions = groupedExecutions.length > 0;
return (
<Dialog styling={{ width: "600px" }}>
<TooltipProvider>
@@ -68,44 +110,141 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
<div className="flex items-center gap-4">
<div className="flex items-center gap-2">
<Text variant="large-medium" className="text-slate-900">
Full Output Preview
Full {dataType === "input" ? "Input" : "Output"} Preview
</Text>
</div>
<div className="rounded-full border border-slate-300 bg-slate-100 px-3 py-1.5 text-xs font-medium text-black">
{dataArray.length} item{dataArray.length !== 1 ? "s" : ""} total
{shouldGroupExecutions ? totalGroupedItems : dataArray.length}{" "}
item
{shouldGroupExecutions
? totalGroupedItems !== 1
? "s"
: ""
: dataArray.length !== 1
? "s"
: ""}{" "}
total
</div>
</div>
<div className="text-sm text-gray-600">
<div className="flex items-center gap-2">
<Text variant="body" className="text-slate-600">
Execution ID:
</Text>
<Text
variant="body-medium"
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
>
{execId}
</Text>
<Button
variant="ghost"
size="small"
onClick={copyExecutionId}
className="h-6 w-6 min-w-0 p-0"
>
<CopyIcon size={14} />
</Button>
</div>
<div className="mt-2">
Pin:{" "}
<span className="font-semibold">{beautifyString(pinName)}</span>
</div>
{shouldGroupExecutions ? (
<div>
Pin:{" "}
<span className="font-semibold">{beautifyString(pinName)}</span>
</div>
) : (
<>
<div className="flex items-center gap-2">
<Text variant="body" className="text-slate-600">
Execution ID:
</Text>
<Text
variant="body-medium"
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
>
{execId}
</Text>
<Button
variant="ghost"
size="small"
onClick={copyExecutionId}
className="h-6 w-6 min-w-0 p-0"
>
<CopyIcon size={14} />
</Button>
</div>
<div className="mt-2">
Pin:{" "}
<span className="font-semibold">
{beautifyString(pinName)}
</span>
</div>
</>
)}
</div>
</div>
<div className="flex-1 overflow-hidden">
<ScrollArea className="h-full">
<div className="my-4">
{dataArray.length > 0 ? (
{shouldGroupExecutions ? (
<div className="space-y-4">
{groupedExecutions.map((execution) => (
<div
key={execution.execId}
className="rounded-3xl border border-slate-200 bg-white p-4 shadow-sm"
>
<div className="flex items-center gap-2">
<Text variant="body" className="text-slate-600">
Execution ID:
</Text>
<Text
variant="body-medium"
className="rounded-full border border-gray-300 bg-gray-50 px-2 py-1 font-mono text-xs"
>
{execution.execId}
</Text>
</div>
<div className="mt-2 space-y-4">
{execution.outputItems.length > 0 ? (
execution.outputItems.map((item, index) => (
<div
key={item.key}
className="group flex items-start gap-4"
>
<div className="w-full flex-1">
<OutputItem
value={item.value}
metadata={item.metadata}
renderer={item.renderer}
/>
</div>
<div className="flex w-fit gap-3">
<Button
variant="secondary"
className="min-w-0 p-1"
size="icon"
onClick={() =>
handleCopyGroupedItem(
execution.execId,
index,
item,
)
}
aria-label="Copy item"
>
{copiedKey ===
`${execution.execId}-${index}` ? (
<CheckIcon className="size-4 text-green-600" />
) : (
<CopyIcon className="size-4 text-black" />
)}
</Button>
<Button
variant="secondary"
size="icon"
className="min-w-0 p-1"
onClick={() =>
handleDownloadGroupedItem(item)
}
aria-label="Download item"
>
<DownloadIcon className="size-4 text-black" />
</Button>
</div>
</div>
))
) : (
<div className="py-4 text-center text-gray-500">
No data available
</div>
)}
</div>
</div>
))}
</div>
) : dataArray.length > 0 ? (
<div className="space-y-4">
{outputItems.map((item, index) => (
<div key={item.key} className="group relative">

View File

@@ -1,82 +1,70 @@
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { beautifyString } from "@/lib/utils";
import React, { useMemo, useState } from "react";
import { useState } from "react";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import {
NodeDataType,
createOutputItems,
getExecutionData,
normalizeToArray,
type OutputItem,
} from "../../helpers";
export type GroupedExecution = {
execId: string;
outputItems: Array<OutputItem>;
};
export const useNodeDataViewer = (
data: any,
pinName: string,
execId: string,
executionResults?: NodeExecutionResult[],
dataType?: NodeDataType,
) => {
const { toast } = useToast();
const [copiedIndex, setCopiedIndex] = useState<number | null>(null);
const [copiedKey, setCopiedKey] = useState<string | null>(null);
// Normalize data to array format
const dataArray = useMemo(() => {
return Array.isArray(data) ? data : [data];
}, [data]);
const dataArray = Array.isArray(data) ? data : [data];
// Prepare items for the enhanced renderer system
const outputItems = useMemo(() => {
if (!dataArray) return [];
const items: Array<{
key: string;
label: string;
value: unknown;
metadata?: OutputMetadata;
renderer: any;
}> = [];
dataArray.forEach((value, index) => {
const metadata: OutputMetadata = {};
// Extract metadata from the value if it's an object
if (
typeof value === "object" &&
value !== null &&
!React.isValidElement(value)
) {
const objValue = value as any;
if (objValue.type) metadata.type = objValue.type;
if (objValue.mimeType) metadata.mimeType = objValue.mimeType;
if (objValue.filename) metadata.filename = objValue.filename;
if (objValue.language) metadata.language = objValue.language;
}
const renderer = globalRegistry.getRenderer(value, metadata);
if (renderer) {
items.push({
key: `item-${index}`,
const outputItems =
!dataArray || dataArray.length === 0
? []
: createOutputItems(dataArray).map((item, index) => ({
...item,
label: index === 0 ? beautifyString(pinName) : "",
value,
metadata,
renderer,
});
} else {
// Fallback to text renderer
const textRenderer = globalRegistry
.getAllRenderers()
.find((r) => r.name === "TextRenderer");
if (textRenderer) {
items.push({
key: `item-${index}`,
label: index === 0 ? beautifyString(pinName) : "",
value:
typeof value === "string"
? value
: JSON.stringify(value, null, 2),
metadata,
renderer: textRenderer,
});
}
}
});
}));
return items;
}, [dataArray, pinName]);
const groupedExecutions =
!executionResults || executionResults.length === 0
? []
: [...executionResults].reverse().map((result) => {
const rawData = getExecutionData(
result,
dataType || "output",
pinName,
);
let dataArray: unknown[];
if (dataType === "input") {
dataArray =
rawData !== undefined && rawData !== null ? [rawData] : [];
} else {
dataArray = normalizeToArray(rawData);
}
const outputItems = createOutputItems(dataArray);
return {
execId: result.node_exec_id,
outputItems,
};
});
const totalGroupedItems = groupedExecutions.reduce(
(total, execution) => total + execution.outputItems.length,
0,
);
const copyExecutionId = () => {
navigator.clipboard.writeText(execId).then(() => {
@@ -122,6 +110,45 @@ export const useNodeDataViewer = (
]);
};
const handleCopyGroupedItem = async (
execId: string,
index: number,
item: OutputItem,
) => {
const copyContent = item.renderer.getCopyContent(item.value, item.metadata);
if (!copyContent) {
return;
}
try {
let text: string;
if (typeof copyContent.data === "string") {
text = copyContent.data;
} else if (copyContent.fallbackText) {
text = copyContent.fallbackText;
} else {
return;
}
await navigator.clipboard.writeText(text);
setCopiedKey(`${execId}-${index}`);
setTimeout(() => setCopiedKey(null), 2000);
} catch (error) {
console.error("Failed to copy:", error);
}
};
const handleDownloadGroupedItem = (item: OutputItem) => {
downloadOutputs([
{
value: item.value,
metadata: item.metadata,
renderer: item.renderer,
},
]);
};
return {
outputItems,
dataArray,
@@ -129,5 +156,10 @@ export const useNodeDataViewer = (
handleCopyItem,
handleDownloadItem,
copiedIndex,
groupedExecutions,
totalGroupedItems,
handleCopyGroupedItem,
handleDownloadGroupedItem,
copiedKey,
};
};

Some files were not shown because too many files have changed in this diff Show More