mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-22 13:38:10 -05:00
Compare commits
4 Commits
add-llm-ma
...
feat/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da9c4a4adf | ||
|
|
0ca73004e5 | ||
|
|
9a786ed8d9 | ||
|
|
0a435e2ffb |
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -128,7 +128,7 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
test:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
@@ -258,39 +258,3 @@ jobs:
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
26
AGENTS.md
26
AGENTS.md
@@ -16,32 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
|
||||
## Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
|
||||
@@ -201,7 +201,7 @@ If you get any pushback or hit complex block conditions check the new_blocks gui
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
### Frontend guidelines:
|
||||
**Frontend feature development:**
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
@@ -217,14 +217,6 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
### Security Implementation
|
||||
|
||||
|
||||
@@ -122,24 +122,6 @@ class ConnectionManager:
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
|
||||
"""Broadcast a message to all active websocket connections."""
|
||||
message = WSMessage(
|
||||
method=method,
|
||||
data=data,
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.active_connections)
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
|
||||
# Return with provider prefix for clarity
|
||||
return f"{provider_name}: {model_name}"
|
||||
|
||||
# Get all models from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
|
||||
# Get the recommended model from the database (configurable via admin UI)
|
||||
recommended_model_slug = await llm_db.get_recommended_model_slug()
|
||||
|
||||
# Build the available models list
|
||||
first_enabled_slug = None
|
||||
for registry_model in llm_registry.iter_dynamic_models():
|
||||
# Only include enabled models in the list
|
||||
if not registry_model.is_enabled:
|
||||
continue
|
||||
|
||||
# Track first enabled model as fallback
|
||||
if first_enabled_slug is None:
|
||||
first_enabled_slug = registry_model.slug
|
||||
|
||||
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
|
||||
label = generate_model_label(model_enum)
|
||||
# Include all LlmModel values (no more filtering by hardcoded list)
|
||||
recommended_model = LlmModel.GPT4O_MINI.value
|
||||
for model in LlmModel:
|
||||
label = generate_model_label(model)
|
||||
# Add "(Recommended)" suffix to the recommended model
|
||||
if registry_model.slug == recommended_model_slug:
|
||||
if model.value == recommended_model:
|
||||
label += " (Recommended)"
|
||||
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value=registry_model.slug,
|
||||
value=model.value,
|
||||
label=label,
|
||||
provider=registry_model.metadata.provider,
|
||||
provider=model.provider,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort models by provider and name for better UX
|
||||
available_models.sort(key=lambda x: (x.provider, x.label))
|
||||
|
||||
# Handle case where no models are available
|
||||
if not available_models:
|
||||
logger.warning(
|
||||
"No enabled LLM models found in registry. "
|
||||
"Ensure models are configured and enabled in the LLM Registry."
|
||||
)
|
||||
# Provide a placeholder entry so admins see meaningful feedback
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
value="",
|
||||
label="No models available - configure in LLM Registry",
|
||||
provider="none",
|
||||
)
|
||||
)
|
||||
|
||||
# Use the DB recommended model, or fallback to first enabled model
|
||||
final_recommended = recommended_model_slug or first_enabled_slug or ""
|
||||
|
||||
return ExecutionAnalyticsConfig(
|
||||
available_models=available_models,
|
||||
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
default_user_prompt=DEFAULT_USER_PROMPT,
|
||||
recommended_model=final_recommended,
|
||||
recommended_model=recommended_model,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,595 +0,0 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
tags=["llm", "admin"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
async def _refresh_runtime_state() -> None:
|
||||
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
|
||||
logger.info("Refreshing LLM registry runtime state...")
|
||||
try:
|
||||
# Refresh registry from database
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated model options
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
logger.info("Cleared all block schema caches")
|
||||
|
||||
# Clear the /blocks endpoint cache so frontend gets updated schemas
|
||||
try:
|
||||
from backend.api.features.v1 import _get_cached_blocks
|
||||
|
||||
_get_cached_blocks.cache_clear()
|
||||
logger.info("Cleared /blocks endpoint cache")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear /blocks cache: %s", e)
|
||||
|
||||
# Clear the v2 builder caches (if they exist)
|
||||
try:
|
||||
from backend.api.features.builder import db as builder_db
|
||||
|
||||
if hasattr(builder_db, "_get_all_providers"):
|
||||
builder_db._get_all_providers.cache_clear()
|
||||
logger.info("Cleared v2 builder providers cache")
|
||||
if hasattr(builder_db, "_build_cached_search_results"):
|
||||
builder_db._build_cached_search_results.cache_clear()
|
||||
logger.info("Cleared v2 builder search results cache")
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear v2 builder cache: %s", e)
|
||||
|
||||
# Notify all executor services to refresh their registry cache
|
||||
from backend.data.llm_registry import publish_registry_refresh_notification
|
||||
|
||||
await publish_registry_refresh_notification()
|
||||
logger.info("Published registry refresh notification")
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"LLM runtime state refresh failed; caches may be stale: %s", exc
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
summary="List LLM providers",
|
||||
response_model=llm_model.LlmProvidersResponse,
|
||||
)
|
||||
async def list_llm_providers(include_models: bool = True):
|
||||
providers = await llm_db.list_providers(include_models=include_models)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/providers",
|
||||
summary="Create LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
|
||||
provider = await llm_db.upsert_provider(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/providers/{provider_id}",
|
||||
summary="Update LLM provider",
|
||||
response_model=llm_model.LlmProvider,
|
||||
)
|
||||
async def update_llm_provider(
|
||||
provider_id: str,
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
):
|
||||
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
|
||||
await _refresh_runtime_state()
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/providers/{provider_id}",
|
||||
summary="Delete LLM provider",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_provider(provider_id: str):
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Delete all models from the provider first before deleting the provider.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_provider(provider_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted LLM provider '%s'", provider_id)
|
||||
return {"success": True, "message": "Provider deleted successfully"}
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
summary="List LLM models",
|
||||
response_model=llm_model.LlmModelsResponse,
|
||||
)
|
||||
async def list_llm_models(
|
||||
provider_id: str | None = fastapi.Query(default=None),
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
return await llm_db.list_models(
|
||||
provider_id=provider_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
summary="Create LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
|
||||
model = await llm_db.create_model(request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}",
|
||||
summary="Update LLM model",
|
||||
response_model=llm_model.LlmModel,
|
||||
)
|
||||
async def update_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
):
|
||||
model = await llm_db.update_model(model_id=model_id, request=request)
|
||||
await _refresh_runtime_state()
|
||||
return model
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/models/{model_id}/toggle",
|
||||
summary="Toggle LLM model availability",
|
||||
response_model=llm_model.ToggleLlmModelResponse,
|
||||
)
|
||||
async def toggle_llm_model(
|
||||
model_id: str,
|
||||
request: llm_model.ToggleLlmModelRequest,
|
||||
):
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
If disabling a model and `migrate_to_slug` is provided, all workflows using
|
||||
this model will be migrated to the specified replacement model before disabling.
|
||||
A migration record is created which can be reverted later using the revert endpoint.
|
||||
|
||||
Optional fields:
|
||||
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
|
||||
- `custom_credit_cost`: Custom pricing override for billing during migration
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.toggle_model(
|
||||
model_id=model_id,
|
||||
is_enabled=request.is_enabled,
|
||||
migrate_to_slug=request.migrate_to_slug,
|
||||
migration_reason=request.migration_reason,
|
||||
custom_credit_cost=request.custom_credit_cost,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
if result.nodes_migrated > 0:
|
||||
logger.info(
|
||||
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
|
||||
result.model.slug,
|
||||
"enabled" if request.is_enabled else "disabled",
|
||||
result.nodes_migrated,
|
||||
result.migrated_to_slug,
|
||||
result.migration_id,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Model toggle validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to toggle model availability",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/{model_id}/usage",
|
||||
summary="Get model usage count",
|
||||
response_model=llm_model.LlmModelUsageResponse,
|
||||
)
|
||||
async def get_llm_model_usage(model_id: str):
|
||||
"""Get the number of workflow nodes using this model."""
|
||||
try:
|
||||
return await llm_db.get_model_usage(model_id=model_id)
|
||||
except ValueError as exc:
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get model usage %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get model usage",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/models/{model_id}",
|
||||
summary="Delete LLM model and migrate workflows",
|
||||
response_model=llm_model.DeleteLlmModelResponse,
|
||||
)
|
||||
async def delete_llm_model(
|
||||
model_id: str,
|
||||
replacement_model_slug: str | None = fastapi.Query(
|
||||
default=None,
|
||||
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Delete a model and optionally migrate workflows using it to a replacement model.
|
||||
|
||||
If no workflows are using this model, it can be deleted without providing a
|
||||
replacement. If workflows exist, replacement_model_slug is required.
|
||||
|
||||
This endpoint:
|
||||
1. Counts how many workflow nodes use the model being deleted
|
||||
2. If nodes exist, validates the replacement model and migrates them
|
||||
3. Deletes the model record
|
||||
4. Refreshes all caches and notifies executors
|
||||
|
||||
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
|
||||
Example (no usage): DELETE /admin/llm/models/{id}
|
||||
"""
|
||||
try:
|
||||
result = await llm_db.delete_model(
|
||||
model_id=model_id, replacement_model_slug=replacement_model_slug
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Deleted model '%s' and migrated %d nodes to '%s'",
|
||||
result.deleted_model_slug,
|
||||
result.nodes_migrated,
|
||||
result.replacement_model_slug,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
# Validation errors (model not found, replacement invalid, etc.)
|
||||
logger.warning("Model deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete model and migrate workflows",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Migration Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations",
|
||||
summary="List model migrations",
|
||||
response_model=llm_model.LlmMigrationsResponse,
|
||||
)
|
||||
async def list_llm_migrations(
|
||||
include_reverted: bool = fastapi.Query(
|
||||
default=False, description="Include reverted migrations in the list"
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all model migrations.
|
||||
|
||||
Migrations are created when disabling a model with the migrate_to_slug option.
|
||||
They can be reverted to restore the original model configuration.
|
||||
"""
|
||||
try:
|
||||
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
|
||||
return llm_model.LlmMigrationsResponse(migrations=migrations)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list migrations: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list migrations",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/migrations/{migration_id}",
|
||||
summary="Get migration details",
|
||||
response_model=llm_model.LlmModelMigration,
|
||||
)
|
||||
async def get_llm_migration(migration_id: str):
|
||||
"""Get details of a specific migration."""
|
||||
try:
|
||||
migration = await llm_db.get_migration(migration_id)
|
||||
if not migration:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Migration '{migration_id}' not found"
|
||||
)
|
||||
return migration
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get migration",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/migrations/{migration_id}/revert",
|
||||
summary="Revert a model migration",
|
||||
response_model=llm_model.RevertMigrationResponse,
|
||||
)
|
||||
async def revert_llm_migration(
|
||||
migration_id: str,
|
||||
request: llm_model.RevertMigrationRequest | None = None,
|
||||
):
|
||||
"""
|
||||
Revert a model migration, restoring affected workflows to their original model.
|
||||
|
||||
This only reverts the specific nodes that were part of the migration.
|
||||
The source model must exist for the revert to succeed.
|
||||
|
||||
Options:
|
||||
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
|
||||
|
||||
Response includes:
|
||||
- `nodes_reverted`: Number of nodes successfully reverted
|
||||
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
|
||||
- `source_model_re_enabled`: Whether the source model was re-enabled
|
||||
|
||||
Requirements:
|
||||
- Migration must not already be reverted
|
||||
- Source model must exist
|
||||
"""
|
||||
try:
|
||||
re_enable = request.re_enable_source_model if request else True
|
||||
result = await llm_db.revert_migration(
|
||||
migration_id,
|
||||
re_enable_source_model=re_enable,
|
||||
)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
|
||||
"(%d already changed, source re-enabled=%s)",
|
||||
migration_id,
|
||||
result.nodes_reverted,
|
||||
result.target_model_slug,
|
||||
result.source_model_slug,
|
||||
result.nodes_already_changed,
|
||||
result.source_model_re_enabled,
|
||||
)
|
||||
return result
|
||||
except ValueError as exc:
|
||||
logger.warning("Migration revert validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to revert migration",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators",
|
||||
summary="List model creators",
|
||||
response_model=llm_model.LlmCreatorsResponse,
|
||||
)
|
||||
async def list_llm_creators():
|
||||
"""
|
||||
List all model creators.
|
||||
|
||||
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
|
||||
This is distinct from providers who host/serve the models (e.g., OpenRouter).
|
||||
"""
|
||||
try:
|
||||
creators = await llm_db.list_creators()
|
||||
return llm_model.LlmCreatorsResponse(creators=creators)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to list creators: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to list creators",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators/{creator_id}",
|
||||
summary="Get creator details",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def get_llm_creator(creator_id: str):
|
||||
"""Get details of a specific model creator."""
|
||||
try:
|
||||
creator = await llm_db.get_creator(creator_id)
|
||||
if not creator:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Creator '{creator_id}' not found"
|
||||
)
|
||||
return creator
|
||||
except fastapi.HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/creators",
|
||||
summary="Create model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
|
||||
"""
|
||||
Create a new model creator.
|
||||
|
||||
A creator represents an organization that creates/trains AI models,
|
||||
such as OpenAI, Anthropic, Meta, or Google.
|
||||
"""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to create creator: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to create creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/creators/{creator_id}",
|
||||
summary="Update model creator",
|
||||
response_model=llm_model.LlmModelCreator,
|
||||
)
|
||||
async def update_llm_creator(
|
||||
creator_id: str,
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
):
|
||||
"""Update an existing model creator."""
|
||||
try:
|
||||
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
|
||||
return creator
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to update creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update creator",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/creators/{creator_id}",
|
||||
summary="Delete model creator",
|
||||
response_model=dict,
|
||||
)
|
||||
async def delete_llm_creator(creator_id: str):
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will remove the creator association from all models that reference it
|
||||
(sets creatorId to NULL), but will not delete the models themselves.
|
||||
"""
|
||||
try:
|
||||
await llm_db.delete_creator(creator_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info("Deleted model creator '%s'", creator_id)
|
||||
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
|
||||
except ValueError as exc:
|
||||
logger.warning("Creator deletion validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete creator",
|
||||
) from exc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recommended Model Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/recommended-model",
|
||||
summary="Get recommended model",
|
||||
response_model=llm_model.RecommendedModelResponse,
|
||||
)
|
||||
async def get_recommended_model():
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
The recommended model is shown to users as the default/suggested option
|
||||
in model selection dropdowns.
|
||||
"""
|
||||
try:
|
||||
model = await llm_db.get_recommended_model()
|
||||
return llm_model.RecommendedModelResponse(
|
||||
model=model,
|
||||
slug=model.slug if model else None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to get recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to get recommended model",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/recommended-model",
|
||||
summary="Set recommended model",
|
||||
response_model=llm_model.SetRecommendedModelResponse,
|
||||
)
|
||||
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This clears the recommended flag from any other model and sets it on
|
||||
the specified model. The model must be enabled to be set as recommended.
|
||||
|
||||
The recommended model is displayed to users as the default/suggested
|
||||
option in model selection dropdowns throughout the platform.
|
||||
"""
|
||||
try:
|
||||
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
|
||||
await _refresh_runtime_state()
|
||||
logger.info(
|
||||
"Set recommended model to '%s' (previous: %s)",
|
||||
model.slug,
|
||||
previous_slug or "none",
|
||||
)
|
||||
return llm_model.SetRecommendedModelResponse(
|
||||
model=model,
|
||||
previous_recommended_slug=previous_slug,
|
||||
message=f"Model '{model.display_name}' is now the recommended model",
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.warning("Set recommended model validation failed: %s", exc)
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to set recommended model: %s", exc)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to set recommended model",
|
||||
) from exc
|
||||
@@ -1,491 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.api.features.admin.llm_routes as llm_routes
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(llm_routes.router, prefix="/admin/llm")
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_list_llm_providers_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM providers"""
|
||||
# Mock the database function
|
||||
mock_providers = [
|
||||
{
|
||||
"id": "provider-1",
|
||||
"name": "openai",
|
||||
"display_name": "OpenAI",
|
||||
"description": "OpenAI LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
{
|
||||
"id": "provider-2",
|
||||
"name": "anthropic",
|
||||
"display_name": "Anthropic",
|
||||
"description": "Anthropic LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": True,
|
||||
"metadata": {},
|
||||
"models": [],
|
||||
},
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_providers",
|
||||
new=AsyncMock(return_value=mock_providers),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["providers"]) == 2
|
||||
assert response_data["providers"][0]["name"] == "openai"
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_providers_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_models_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful listing of LLM models with pagination"""
|
||||
# Mock the database function - now returns LlmModelsResponse
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=True,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[
|
||||
llm_model.LlmModelCost(
|
||||
id="cost-1",
|
||||
credit_cost=10,
|
||||
credential_provider="openai",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mock_response = llm_model.LlmModelsResponse(
|
||||
models=[mock_model],
|
||||
pagination=Pagination(
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=50,
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.list_models",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
response = client.get("/admin/llm/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert len(response_data["models"]) == 1
|
||||
assert response_data["models"][0]["slug"] == "gpt-4o"
|
||||
assert response_data["pagination"]["total_items"] == 1
|
||||
assert response_data["pagination"]["page_size"] == 50
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"list_llm_models_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_provider_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM provider"""
|
||||
mock_provider = {
|
||||
"id": "new-provider-id",
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
|
||||
new=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "groq",
|
||||
"display_name": "Groq",
|
||||
"description": "Groq LLM provider",
|
||||
"supports_tools": True,
|
||||
"supports_json_output": True,
|
||||
"supports_reasoning": False,
|
||||
"supports_parallel_tool": False,
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/providers", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["name"] == "groq"
|
||||
assert response_data["display_name"] == "Groq"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_provider_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful creation of LLM model"""
|
||||
mock_model = {
|
||||
"id": "new-model-id",
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-id",
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.create_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"slug": "gpt-4.1-mini",
|
||||
"display_name": "GPT-4.1 Mini",
|
||||
"description": "Latest GPT-4.1 Mini model",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 128000,
|
||||
"max_output_tokens": 16384,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"credit_cost": 5,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/admin/llm/models", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["slug"] == "gpt-4.1-mini"
|
||||
assert response_data["is_enabled"] is True
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"create_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_update_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful update of LLM model"""
|
||||
mock_model = {
|
||||
"id": "model-1",
|
||||
"slug": "gpt-4o",
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"provider_id": "provider-1",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
"is_enabled": True,
|
||||
"capabilities": {},
|
||||
"metadata": {},
|
||||
"costs": [
|
||||
{
|
||||
"id": "cost-1",
|
||||
"credit_cost": 15,
|
||||
"credential_provider": "openai",
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.update_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"display_name": "GPT-4o Updated",
|
||||
"description": "Updated description",
|
||||
"context_window": 256000,
|
||||
"max_output_tokens": 32768,
|
||||
}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["display_name"] == "GPT-4o Updated"
|
||||
assert response_data["context_window"] == 256000
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"update_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_toggle_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful toggling of LLM model enabled status"""
|
||||
# Create a proper mock model object
|
||||
mock_model = llm_model.LlmModel(
|
||||
id="model-1",
|
||||
slug="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
description="GPT-4 Optimized",
|
||||
provider_id="provider-1",
|
||||
context_window=128000,
|
||||
max_output_tokens=16384,
|
||||
is_enabled=False,
|
||||
capabilities={},
|
||||
metadata={},
|
||||
costs=[],
|
||||
)
|
||||
|
||||
# Create a proper ToggleLlmModelResponse
|
||||
mock_response = llm_model.ToggleLlmModelResponse(
|
||||
model=mock_model,
|
||||
nodes_migrated=0,
|
||||
migrated_to_slug=None,
|
||||
migration_id=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
request_data = {"is_enabled": False}
|
||||
|
||||
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["model"]["is_enabled"] is False
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"toggle_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test successful deletion of LLM model with migration"""
|
||||
# Create a proper DeleteLlmModelResponse
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="gpt-3.5-turbo",
|
||||
deleted_model_display_name="GPT-3.5 Turbo",
|
||||
replacement_model_slug="gpt-4o-mini",
|
||||
nodes_migrated=42,
|
||||
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
|
||||
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete(
|
||||
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
|
||||
assert response_data["nodes_migrated"] == 42
|
||||
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response (must be string)
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"delete_llm_model_success.json",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_llm_model_validation_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails with proper error when validation fails"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Replacement model 'invalid' not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_with_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion fails when nodes exist but no replacement is provided"""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(
|
||||
side_effect=ValueError(
|
||||
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
|
||||
"Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "workflow node(s) are using it" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_llm_model_no_replacement_no_usage(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
|
||||
mock_response = llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug="unused-model",
|
||||
deleted_model_display_name="Unused Model",
|
||||
replacement_model_slug=None,
|
||||
nodes_migrated=0,
|
||||
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.llm_routes.llm_db.delete_model",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
)
|
||||
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.api.features.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.delete("/admin/llm/models/model-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["deleted_model_slug"] == "unused-model"
|
||||
assert response_data["nodes_migrated"] == 0
|
||||
assert response_data["replacement_model_slug"] is None
|
||||
mock_refresh.assert_called_once()
|
||||
@@ -15,7 +15,6 @@ from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.llm_registry import get_all_model_slugs_for_validation
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -32,14 +31,7 @@ from .model import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_llm_models() -> list[str]:
|
||||
"""Get LLM model names for search matching from the registry."""
|
||||
return [
|
||||
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
|
||||
]
|
||||
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
|
||||
MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
@@ -504,8 +496,8 @@ async def _get_static_counts():
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models from registry
|
||||
if any(query in name for name in _get_llm_models()):
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -290,11 +290,6 @@ async def _cache_session(session: ChatSession) -> None:
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
|
||||
@@ -172,12 +172,12 @@ async def get_session(
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
@@ -222,8 +222,6 @@ async def stream_chat_post(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
@@ -232,26 +230,7 @@ async def stream_chat_post(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -296,8 +275,6 @@ async def stream_chat_get(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
@@ -305,26 +282,7 @@ async def stream_chat_get(
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import get_client, propagate_attributes
|
||||
from langfuse.openai import openai # type: ignore
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIStatusError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
)
|
||||
from openai import APIConnectionError, APIError, APIStatusError, RateLimitError
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.understanding import (
|
||||
@@ -29,7 +21,6 @@ from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
cache_chat_session,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
@@ -305,10 +296,6 @@ async def stream_chat_completion(
|
||||
content="",
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_saved_assistant_message = False
|
||||
has_appended_streaming_message = False
|
||||
last_cache_time = 0.0
|
||||
last_cache_content_len = 0
|
||||
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
has_yielded_end = False
|
||||
@@ -345,23 +332,6 @@ async def stream_chat_completion(
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_streaming_message = True
|
||||
current_time = time.monotonic()
|
||||
content_len = len(assistant_response.content)
|
||||
if (
|
||||
current_time - last_cache_time >= 1.0
|
||||
and content_len > last_cache_content_len
|
||||
):
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache partial session {session.session_id}: {e}"
|
||||
)
|
||||
last_cache_time = current_time
|
||||
last_cache_content_len = content_len
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
@@ -420,42 +390,10 @@ async def stream_chat_completion(
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
|
||||
# Save assistant message before yielding finish to ensure it's persisted
|
||||
# even if client disconnects immediately after receiving StreamFinish
|
||||
if not has_saved_assistant_message:
|
||||
messages_to_save_early: list[ChatMessage] = []
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = (
|
||||
accumulated_tool_calls
|
||||
)
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content
|
||||
or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save_early.append(assistant_response)
|
||||
messages_to_save_early.extend(tool_response_messages)
|
||||
|
||||
if messages_to_save_early:
|
||||
session.messages.extend(messages_to_save_early)
|
||||
logger.info(
|
||||
f"Saving assistant message before StreamFinish: "
|
||||
f"content_len={len(assistant_response.content or '')}, "
|
||||
f"tool_calls={len(assistant_response.tool_calls or [])}, "
|
||||
f"tool_responses={len(tool_response_messages)}"
|
||||
)
|
||||
if (
|
||||
messages_to_save_early
|
||||
or has_appended_streaming_message
|
||||
):
|
||||
await upsert_chat_session(session)
|
||||
has_saved_assistant_message = True
|
||||
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
@@ -475,27 +413,6 @@ async def stream_chat_completion(
|
||||
langfuse.update_current_trace(output=str(tool_response_messages))
|
||||
langfuse.update_current_span(output=str(tool_response_messages))
|
||||
|
||||
except CancelledError:
|
||||
if not has_saved_assistant_message:
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content:
|
||||
assistant_response.content = (
|
||||
f"{assistant_response.content}\n\n[interrupted]"
|
||||
)
|
||||
else:
|
||||
assistant_response.content = "[interrupted]"
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
if tool_response_messages:
|
||||
session.messages.extend(tool_response_messages)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
@@ -517,19 +434,14 @@ async def stream_chat_completion(
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
if not has_saved_assistant_message:
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
@@ -560,49 +472,38 @@ async def stream_chat_completion(
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
# Only save if we haven't already saved when StreamFinish was received
|
||||
if not has_saved_assistant_message:
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
else:
|
||||
logger.info(
|
||||
"Assistant message already saved when StreamFinish was received, "
|
||||
"skipping duplicate save"
|
||||
)
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
@@ -644,12 +545,6 @@ def _is_retryable_error(error: Exception) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_region_blocked_error(error: Exception) -> bool:
|
||||
if isinstance(error, PermissionDeniedError):
|
||||
return "not available in your region" in str(error).lower()
|
||||
return "not available in your region" in str(error).lower()
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
@@ -842,18 +737,7 @@ async def _stream_chat_chunks(
|
||||
f"Error in stream (not retrying): {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_code = None
|
||||
error_text = str(e)
|
||||
if _is_region_blocked_error(e):
|
||||
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
||||
error_text = (
|
||||
"This model is not available in your region. "
|
||||
"Please connect via VPN and try again."
|
||||
)
|
||||
error_response = StreamError(
|
||||
errorText=error_text,
|
||||
code=error_code,
|
||||
)
|
||||
error_response = StreamError(errorText=str(e))
|
||||
yield error_response
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
"json_to_graph",
|
||||
# Exceptions
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
# Service
|
||||
"is_external_service_configured",
|
||||
"check_external_service_health",
|
||||
]
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -9,13 +7,35 @@ from typing import Any
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
from .service import (
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
is_external_service_configured,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_service_configured() -> None:
|
||||
"""Check if the external Agent Generator service is configured.
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
raise AgentGeneratorNotConfiguredError(
|
||||
"Agent Generator service is not configured. "
|
||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||
)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
return await decompose_goal_external(description, context)
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(instructions)
|
||||
if result:
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
@@ -104,12 +80,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
return result
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
@@ -284,108 +255,23 @@ async def get_agent_as_json(
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Generating the patch
|
||||
- Applying the patch
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(update_request, current_agent)
|
||||
|
||||
@@ -1,606 +0,0 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -0,0 +1,269 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def _get_settings() -> Settings:
|
||||
"""Get or create settings singleton."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL for the external service."""
|
||||
settings = _get_settings()
|
||||
host = settings.config.agentgenerator_host
|
||||
port = settings.config.agentgenerator_port
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client for the external service."""
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str, context: str = ""
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
Or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build the request payload
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if context:
|
||||
# The external service uses user_instruction for additional context
|
||||
payload["user_instruction"] = context
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return None
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/api/blocks")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error("External service returned error getting blocks")
|
||||
return None
|
||||
|
||||
return data.get("blocks", [])
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error getting blocks from external service: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||
except Exception as e:
|
||||
logger.warning(f"External agent generator health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
"""Close the HTTP client."""
|
||||
global _client
|
||||
if _client is not None:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
@@ -1,213 +0,0 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,279 +0,0 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
@@ -8,12 +8,10 @@ from langfuse import observe
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -27,9 +25,6 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
@@ -91,9 +86,8 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
@@ -110,11 +104,13 @@ class CreateAgentTool(BaseTool):
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -171,72 +167,32 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||
try:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
|
||||
@@ -8,13 +8,10 @@ from langfuse import observe
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -28,9 +25,6 @@ from .models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
@@ -43,7 +37,7 @@ class EditAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
"Generates updates to the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -98,9 +92,8 @@ class EditAgentTool(BaseTool):
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
@@ -137,121 +130,58 @@ class EditAgentTool(BaseTool):
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||
try:
|
||||
result = await generate_agent_patch(update_request, current_agent)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent editing is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Update generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
# Check if LLM returned clarifying questions
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result is the updated agent JSON
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"I've updated the agent. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
@@ -277,10 +207,7 @@ class EditAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
|
||||
@@ -393,7 +393,6 @@ async def get_creators(
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
operation_id="getV2GetCreatorDetails",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.llm_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -38,11 +37,9 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm.routes as public_llm_routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -112,27 +109,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated discriminator_mapping
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
# migrate_llm_models uses registry default model
|
||||
from backend.blocks.llm import LlmModel
|
||||
|
||||
default_model_slug = llm_registry.get_default_model_slug()
|
||||
if default_model_slug:
|
||||
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
|
||||
else:
|
||||
logger.warning("Skipping LLM model migration: no default model available")
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -317,16 +298,6 @@ app.include_router(
|
||||
tags=["v2", "executions", "review"],
|
||||
prefix="/api/review",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.llm_routes.router,
|
||||
tags=["v2", "admin", "llm"],
|
||||
prefix="/api/llm/admin",
|
||||
)
|
||||
app.include_router(
|
||||
public_llm_routes.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
|
||||
@@ -77,39 +77,7 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
async def registry_refresh_worker():
|
||||
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
|
||||
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
|
||||
)
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if (
|
||||
message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info(
|
||||
"Broadcasting LLM registry refresh to all WebSocket clients"
|
||||
)
|
||||
await manager.broadcast_to_all(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={
|
||||
"type": "LLM_REGISTRY_REFRESH",
|
||||
"event": "registry_updated",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
execution_worker(),
|
||||
notification_worker(),
|
||||
registry_refresh_worker(),
|
||||
)
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
@@ -9,7 +10,6 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
llm_model_schema_extra,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
|
||||
@@ -4,19 +4,17 @@ import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Iterable, List, Literal, Optional
|
||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -24,7 +22,6 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.llm_registry import ModelMetadata
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -69,117 +66,114 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
def AICredentialsField() -> AICredentials:
|
||||
"""
|
||||
Returns a CredentialsField for LLM providers.
|
||||
The discriminator_mapping will be refreshed when the schema is generated
|
||||
if it's empty, ensuring the LLM registry is loaded.
|
||||
"""
|
||||
# Get the mapping now - it may be empty initially, but will be refreshed
|
||||
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
|
||||
mapping = llm_registry.get_llm_discriminator_mapping()
|
||||
|
||||
return CredentialsField(
|
||||
description="API key for the LLM provider.",
|
||||
discriminator="model",
|
||||
discriminator_mapping=mapping, # May be empty initially, refreshed later
|
||||
discriminator_mapping={
|
||||
model.value: model.metadata.provider for model in LlmModel
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def llm_model_schema_extra() -> dict[str, Any]:
|
||||
return {"options": llm_registry.get_llm_model_schema_options()}
|
||||
class ModelMetadata(NamedTuple):
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
|
||||
|
||||
class LlmModelMeta(type):
|
||||
"""
|
||||
Metaclass for LlmModel that enables attribute-style access to dynamic models.
|
||||
|
||||
This allows code like `LlmModel.GPT4O` to work by converting the attribute
|
||||
name to a slug format:
|
||||
- GPT4O -> gpt-4o
|
||||
- GPT4O_MINI -> gpt-4o-mini
|
||||
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
|
||||
"""
|
||||
|
||||
def __getattr__(cls, name: str):
|
||||
# Don't intercept private/dunder attributes
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
|
||||
|
||||
# Convert attribute name to slug format:
|
||||
# 1. Lowercase: GPT4O -> gpt4o
|
||||
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
|
||||
# 3. Insert hyphen between letter and digit: gpt4o -> gpt-4o
|
||||
slug = name.lower().replace("_", "-")
|
||||
slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
|
||||
|
||||
return cls(slug)
|
||||
|
||||
def __iter__(cls):
|
||||
"""Iterate over all models from the registry.
|
||||
|
||||
Yields LlmModel instances for each model in the dynamic registry.
|
||||
Used by __get_pydantic_json_schema__ to build model metadata.
|
||||
"""
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
yield cls(model.slug)
|
||||
class LlmModelMeta(EnumMeta):
|
||||
pass
|
||||
|
||||
|
||||
class LlmModel(str, metaclass=LlmModelMeta):
|
||||
"""
|
||||
Dynamic LLM model type that accepts any model slug from the registry.
|
||||
|
||||
This is a string subclass (not an Enum) that allows any model slug value.
|
||||
All models are managed via the LLM Registry in the database.
|
||||
|
||||
Usage:
|
||||
model = LlmModel("gpt-4o") # Direct construction
|
||||
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
|
||||
model.value # Returns the slug string
|
||||
model.provider # Returns the provider from registry
|
||||
"""
|
||||
|
||||
def __new__(cls, value: str):
|
||||
if isinstance(value, LlmModel):
|
||||
return value
|
||||
return str.__new__(cls, value)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""
|
||||
Tell Pydantic how to validate LlmModel.
|
||||
|
||||
Accepts strings and converts them to LlmModel instances.
|
||||
"""
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls, # The validator function (LlmModel constructor)
|
||||
core_schema.str_schema(), # Accept string input
|
||||
serialization=core_schema.to_string_ser_schema(), # Serialize as string
|
||||
)
|
||||
|
||||
@property
|
||||
def value(self) -> str:
|
||||
"""Return the model slug (for compatibility with enum-style access)."""
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "LlmModel":
|
||||
"""
|
||||
Get the default model from the registry.
|
||||
|
||||
Returns the recommended model if set, otherwise gpt-4o if available
|
||||
and enabled, otherwise the first enabled model from the registry.
|
||||
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
|
||||
"""
|
||||
from backend.data.llm_registry import get_default_model_slug
|
||||
|
||||
slug = get_default_model_slug()
|
||||
if slug is None:
|
||||
# Registry is empty (e.g., at module import time before DB connection).
|
||||
# Fall back to gpt-4o for backward compatibility.
|
||||
slug = "gpt-4o"
|
||||
return cls(slug)
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5_2 = "gpt-5.2-2025-12-11"
|
||||
GPT5_1 = "gpt-5.1-2025-11-13"
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
||||
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
|
||||
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
|
||||
# Groq models
|
||||
LLAMA3_3_70B = "llama-3.3-70b-versatile"
|
||||
LLAMA3_1_8B = "llama-3.1-8b-instant"
|
||||
# Ollama models
|
||||
OLLAMA_LLAMA3_3 = "llama3.3"
|
||||
OLLAMA_LLAMA3_2 = "llama3.2"
|
||||
OLLAMA_LLAMA3_8B = "llama3"
|
||||
OLLAMA_LLAMA3_405B = "llama3.1:405b"
|
||||
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
|
||||
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
|
||||
# v0 by Vercel models
|
||||
V0_1_5_MD = "v0-1.5-md"
|
||||
V0_1_5_LG = "v0-1.5-lg"
|
||||
V0_1_0_MD = "v0-1.0-md"
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, schema, handler):
|
||||
@@ -187,12 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
llm_model_metadata = {}
|
||||
for model in cls:
|
||||
model_name = model.value
|
||||
# Use registry directly with None check to gracefully handle
|
||||
# missing metadata during startup/import before registry is populated
|
||||
metadata = llm_registry.get_llm_model_metadata(model_name)
|
||||
if metadata is None:
|
||||
# Skip models without metadata (registry not yet populated)
|
||||
continue
|
||||
metadata = model.metadata
|
||||
llm_model_metadata[model_name] = {
|
||||
"creator": metadata.creator_name,
|
||||
"creator_name": metadata.creator_name,
|
||||
@@ -208,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
raise ValueError(
|
||||
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
|
||||
)
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
@@ -228,11 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
|
||||
return self.metadata.max_output_tokens
|
||||
|
||||
|
||||
# MODEL_METADATA removed - all models now come from the database via llm_registry
|
||||
MODEL_METADATA = {
|
||||
# https://platform.openai.com/docs/models
|
||||
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
|
||||
LlmModel.O3_MINI: ModelMetadata(
|
||||
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
|
||||
), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata(
|
||||
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
|
||||
), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata(
|
||||
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
|
||||
), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
|
||||
),
|
||||
LlmModel.GPT5_1: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT5: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_MINI: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_NANO: ModelMetadata(
|
||||
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata(
|
||||
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
|
||||
),
|
||||
LlmModel.GPT41: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT41_MINI: ModelMetadata(
|
||||
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
|
||||
),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
|
||||
), # gpt-4o-mini-2024-07-18
|
||||
LlmModel.GPT4O: ModelMetadata(
|
||||
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
|
||||
), # gpt-4o-2024-08-06
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||
), # claude-3-haiku-20240307
|
||||
# https://docs.aimlapi.com/api-overview/model-database/text-models
|
||||
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
|
||||
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
|
||||
"aiml_api",
|
||||
128000,
|
||||
40000,
|
||||
"Llama 3.1 Nemotron 70B Instruct",
|
||||
"AI/ML",
|
||||
"Nvidia",
|
||||
1,
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
|
||||
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
|
||||
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
|
||||
),
|
||||
# https://console.groq.com/docs/models
|
||||
LlmModel.LLAMA3_3_70B: ModelMetadata(
|
||||
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata(
|
||||
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
|
||||
),
|
||||
# https://ollama.com/library
|
||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
|
||||
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
|
||||
),
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65535,
|
||||
"Gemini 2.5 Flash Lite Preview 06.17",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
8192,
|
||||
"Gemini 2.0 Flash Lite 001",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
|
||||
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
|
||||
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
16000,
|
||||
"Sonar Deep Research",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
3,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||
"open_router",
|
||||
131000,
|
||||
4096,
|
||||
"Hermes 3 Llama 3.1 405B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||
"open_router",
|
||||
12288,
|
||||
12288,
|
||||
"Hermes 3 Llama 3.1 70B",
|
||||
"OpenRouter",
|
||||
"Nous Research",
|
||||
1,
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
|
||||
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
|
||||
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
|
||||
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.KIMI_K2: ModelMetadata(
|
||||
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
|
||||
),
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
262144,
|
||||
"Qwen 3 235B A22B Thinking 2507",
|
||||
"OpenRouter",
|
||||
"Qwen",
|
||||
1,
|
||||
),
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Scout 17B 16E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
|
||||
"llama_api",
|
||||
128000,
|
||||
4028,
|
||||
"Llama 4 Maverick 17B 128E Instruct FP8",
|
||||
"Llama API",
|
||||
"Meta",
|
||||
1,
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
|
||||
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
|
||||
),
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
|
||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
|
||||
}
|
||||
|
||||
# Default model constant for backward compatibility
|
||||
# Uses the dynamic registry to get the default model
|
||||
DEFAULT_LLM_MODEL = LlmModel.default()
|
||||
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -361,98 +634,19 @@ async def llm_call(
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
# Get model metadata and check if enabled - with fallback support
|
||||
# The model we'll actually use (may differ if original is disabled)
|
||||
model_to_use = llm_model.value
|
||||
|
||||
# Check if model is in registry and if it's enabled
|
||||
from backend.data.llm_registry import (
|
||||
get_fallback_model_for_disabled,
|
||||
get_model_info,
|
||||
)
|
||||
|
||||
model_info = get_model_info(llm_model.value)
|
||||
|
||||
if model_info and not model_info.is_enabled:
|
||||
# Model is disabled - try to find a fallback from the same provider
|
||||
fallback = get_fallback_model_for_disabled(llm_model.value)
|
||||
if fallback:
|
||||
logger.warning(
|
||||
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
|
||||
)
|
||||
model_to_use = fallback.slug
|
||||
# Use fallback model's metadata
|
||||
provider = fallback.metadata.provider
|
||||
context_window = fallback.metadata.context_window
|
||||
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
|
||||
else:
|
||||
# No fallback available - raise error
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled and no fallback model "
|
||||
f"from the same provider is available. Please enable the model or "
|
||||
f"select a different model in the block configuration."
|
||||
)
|
||||
else:
|
||||
# Model is enabled or not in registry (legacy/static model)
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
except ValueError:
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
logger.warning(f"Model {llm_model.value} not found in registry cache")
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info(
|
||||
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
|
||||
)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
logger.info(
|
||||
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
|
||||
)
|
||||
except ValueError:
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
)
|
||||
|
||||
# Create effective model for model-specific parameter resolution (e.g., o-series check)
|
||||
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
|
||||
effective_model = LlmModel(model_to_use)
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=context_window // 2,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
# model_max_output already set above
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
@@ -463,14 +657,14 @@ async def llm_call(
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
@@ -517,7 +711,7 @@ async def llm_call(
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
@@ -581,7 +775,7 @@ async def llm_call(
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -603,7 +797,7 @@ async def llm_call(
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = await client.generate(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
options={"num_ctx": max_tokens},
|
||||
@@ -625,7 +819,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -633,7 +827,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -667,7 +861,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
@@ -675,7 +869,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -702,7 +896,7 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "aiml_api":
|
||||
client = openai.AsyncOpenAI(
|
||||
client = openai.OpenAI(
|
||||
base_url="https://api.aimlapi.com/v2",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
default_headers={
|
||||
@@ -712,8 +906,8 @@ async def llm_call(
|
||||
},
|
||||
)
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
completion = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@@ -741,11 +935,11 @@ async def llm_call(
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
effective_model, parallel_tool_calls
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -796,10 +990,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
@@ -862,7 +1055,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -1228,10 +1421,9 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
@@ -1325,9 +1517,8 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for summarizing the text.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
title="Focus",
|
||||
@@ -1543,9 +1734,8 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for the conversation.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -1582,7 +1772,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -1645,10 +1835,9 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=LlmModel.default,
|
||||
default=DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_retries: int = SchemaField(
|
||||
@@ -1703,7 +1892,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
"force_json_output": False,
|
||||
|
||||
@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default_factory=llm.LlmModel.default,
|
||||
default=llm.DEFAULT_LLM_MODEL,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm.llm_model_schema_extra(),
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
|
||||
@@ -10,13 +10,13 @@ import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.data import llm_registry
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = self.metadata
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.metadata.provider
|
||||
return MODEL_METADATA[LlmModel(self.value)].provider
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
metadata = llm_registry.get_llm_model_metadata(self.value)
|
||||
if metadata:
|
||||
return metadata
|
||||
# Fallback to LlmModel enum if registry lookup fails
|
||||
return LlmModel(self.value).metadata
|
||||
return MODEL_METADATA[LlmModel(self.value)]
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return self.metadata.context_window
|
||||
return MODEL_METADATA[LlmModel(self.value)].context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return self.metadata.max_output_tokens
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
|
||||
@@ -25,7 +25,6 @@ from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
@@ -144,59 +143,35 @@ class BlockInfo(BaseModel):
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
|
||||
|
||||
@classmethod
|
||||
def clear_schema_cache(cls) -> None:
|
||||
"""Clear the cached JSON schema for this class."""
|
||||
# Use None instead of {} because {} is truthy and would prevent regeneration
|
||||
cls.cached_jsonschema = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def clear_all_schema_caches() -> None:
|
||||
"""Clear cached JSON schemas for all BlockSchema subclasses."""
|
||||
|
||||
def clear_recursive(cls: type) -> None:
|
||||
"""Recursively clear cache for class and all subclasses."""
|
||||
if hasattr(cls, "clear_schema_cache"):
|
||||
cls.clear_schema_cache()
|
||||
for subclass in cls.__subclasses__():
|
||||
clear_recursive(subclass)
|
||||
|
||||
clear_recursive(BlockSchema)
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@classmethod
|
||||
def jsonschema(cls) -> dict[str, Any]:
|
||||
# Generate schema if not cached
|
||||
if not cls.cached_jsonschema:
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
if cls.cached_jsonschema:
|
||||
return cls.cached_jsonschema
|
||||
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next(
|
||||
(k for k in keys if k in obj and len(obj[k]) == 1), None
|
||||
)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
||||
keys = {"allOf", "anyOf", "oneOf"}
|
||||
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
||||
if one_key:
|
||||
obj.update(obj[one_key][0])
|
||||
|
||||
return obj
|
||||
return {
|
||||
key: ref_to_dict(value)
|
||||
for key, value in obj.items()
|
||||
if not key.startswith("$") and key != one_key
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [ref_to_dict(item) for item in obj]
|
||||
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
return obj
|
||||
|
||||
# Always post-process to ensure LLM registry data is up-to-date
|
||||
# This refreshes model options and discriminator mappings even if schema was cached
|
||||
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@@ -259,7 +234,7 @@ class BlockSchema(BaseModel):
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
# Reset cached JSON schema to prevent inheriting it from parent class
|
||||
cls.cached_jsonschema = None
|
||||
cls.cached_jsonschema = {}
|
||||
|
||||
credentials_fields = cls.get_credentials_fields()
|
||||
|
||||
@@ -896,28 +871,6 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# This ensures the registry cache is populated even in executor context
|
||||
try:
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
|
||||
# Only refresh if we have DB access (check if Prisma is connected)
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
logger.info("LLM registry refreshed during block initialization")
|
||||
else:
|
||||
logger.warning(
|
||||
"Prisma not connected, skipping LLM registry refresh during block initialization"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM registry during block initialization: %s", exc
|
||||
)
|
||||
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
|
||||
@@ -24,18 +23,19 @@ from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
AIListGeneratorBlock,
|
||||
AIStructuredResponseGeneratorBlock,
|
||||
AITextGeneratorBlock,
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
@@ -55,63 +55,210 @@ from backend.integrations.credentials_store import (
|
||||
v0_credentials,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
|
||||
PROVIDER_CREDENTIALS = {
|
||||
"openai": openai_credentials,
|
||||
"anthropic": anthropic_credentials,
|
||||
"groq": groq_credentials,
|
||||
"open_router": open_router_credentials,
|
||||
"llama_api": llama_api_credentials,
|
||||
"aiml_api": aiml_api_credentials,
|
||||
"v0": v0_credentials,
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT41_MINI: 1,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
LlmModel.QWEN3_CODER: 9,
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: 1,
|
||||
LlmModel.V0_1_5_LG: 2,
|
||||
LlmModel.V0_1_0_MD: 1,
|
||||
}
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
# All LLM costs now come from the database via llm_registry
|
||||
|
||||
LLM_COST: list[BlockCost] = []
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_COST:
|
||||
raise ValueError(f"Missing MODEL_COST for model: {model}")
|
||||
|
||||
|
||||
def _build_llm_costs_from_registry() -> list[BlockCost]:
|
||||
"""Build BlockCost list from all models in the LLM registry."""
|
||||
costs: list[BlockCost] = []
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
for cost in model.costs:
|
||||
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
|
||||
if not credentials:
|
||||
logger.warning(
|
||||
"Skipping cost entry for %s due to unknown credentials provider %s",
|
||||
model.slug,
|
||||
cost.credential_provider,
|
||||
)
|
||||
continue
|
||||
cost_filter = {
|
||||
"model": model.slug,
|
||||
LLM_COST = (
|
||||
# Anthropic Models
|
||||
[
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": credentials.id,
|
||||
"provider": credentials.provider,
|
||||
"type": credentials.type,
|
||||
"id": anthropic_credentials.id,
|
||||
"provider": anthropic_credentials.provider,
|
||||
"type": anthropic_credentials.type,
|
||||
},
|
||||
}
|
||||
costs.append(
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter=cost_filter,
|
||||
cost_amount=cost.credit_cost,
|
||||
)
|
||||
)
|
||||
return costs
|
||||
|
||||
|
||||
def refresh_llm_costs() -> None:
|
||||
"""Refresh LLM costs from the registry. All costs now come from the database."""
|
||||
LLM_COST.clear()
|
||||
LLM_COST.extend(_build_llm_costs_from_registry())
|
||||
|
||||
|
||||
# Initial load will happen after registry is refreshed at startup
|
||||
# Don't call refresh_llm_costs() here - it will be called after registry refresh
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "anthropic"
|
||||
]
|
||||
# OpenAI Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
"provider": openai_credentials.provider,
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "openai"
|
||||
]
|
||||
# Groq Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {"id": groq_credentials.id},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "groq"
|
||||
]
|
||||
# Open Router Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "open_router"
|
||||
]
|
||||
# Llama API Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": llama_api_credentials.id,
|
||||
"provider": llama_api_credentials.provider,
|
||||
"type": llama_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "llama_api"
|
||||
]
|
||||
# v0 by Vercel Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": v0_credentials.id,
|
||||
"provider": v0_credentials.provider,
|
||||
"type": v0_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "v0"
|
||||
]
|
||||
# AI/ML Api Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": aiml_api_credentials.id,
|
||||
"provider": aiml_api_credentials.provider,
|
||||
"type": aiml_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "aiml_api"
|
||||
]
|
||||
)
|
||||
|
||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||
|
||||
|
||||
@@ -1511,10 +1511,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Get all model slugs from the registry (dynamic, not hardcoded enum)
|
||||
from backend.data import llm_registry
|
||||
|
||||
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel]
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
|
||||
# Update each block
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
LLM Registry module for managing LLM models, providers, and costs dynamically.
|
||||
|
||||
This module provides a database-driven registry system for LLM models,
|
||||
replacing hardcoded model configurations with a flexible admin-managed system.
|
||||
"""
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
# Re-export for backwards compatibility
|
||||
from backend.data.llm_registry.notifications import (
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
publish_registry_refresh_notification,
|
||||
subscribe_to_registry_refresh,
|
||||
)
|
||||
from backend.data.llm_registry.registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_dynamic_model_slugs,
|
||||
get_fallback_model_for_disabled,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_cost,
|
||||
get_llm_model_metadata,
|
||||
get_llm_model_schema_options,
|
||||
get_model_info,
|
||||
is_model_enabled,
|
||||
iter_dynamic_models,
|
||||
refresh_llm_registry,
|
||||
register_static_costs,
|
||||
register_static_metadata,
|
||||
)
|
||||
from backend.data.llm_registry.schema_utils import (
|
||||
is_llm_model_field,
|
||||
refresh_llm_discriminator_mapping,
|
||||
refresh_llm_model_options,
|
||||
update_schema_with_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Registry functions
|
||||
"get_all_model_slugs_for_validation",
|
||||
"get_default_model_slug",
|
||||
"get_dynamic_model_slugs",
|
||||
"get_fallback_model_for_disabled",
|
||||
"get_llm_discriminator_mapping",
|
||||
"get_llm_model_cost",
|
||||
"get_llm_model_metadata",
|
||||
"get_llm_model_schema_options",
|
||||
"get_model_info",
|
||||
"is_model_enabled",
|
||||
"iter_dynamic_models",
|
||||
"refresh_llm_registry",
|
||||
"register_static_costs",
|
||||
"register_static_metadata",
|
||||
# Notifications
|
||||
"REGISTRY_REFRESH_CHANNEL",
|
||||
"publish_registry_refresh_notification",
|
||||
"subscribe_to_registry_refresh",
|
||||
# Schema utilities
|
||||
"is_llm_model_field",
|
||||
"refresh_llm_discriminator_mapping",
|
||||
"refresh_llm_model_options",
|
||||
"update_schema_with_llm_registry",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Type definitions for LLM model metadata."""
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
"""Metadata for an LLM model.
|
||||
|
||||
Attributes:
|
||||
provider: The provider identifier (e.g., "openai", "anthropic")
|
||||
context_window: Maximum context window size in tokens
|
||||
max_output_tokens: Maximum output tokens (None if unlimited)
|
||||
display_name: Human-readable name for the model
|
||||
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
|
||||
creator_name: Name of the organization that created the model
|
||||
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
|
||||
"""
|
||||
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
Redis pub/sub notifications for LLM registry updates.
|
||||
|
||||
When models are added/updated/removed via the admin UI, this module
|
||||
publishes notifications to Redis that all executor services subscribe to,
|
||||
ensuring they refresh their registry cache in real-time.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis channel name for LLM registry refresh notifications
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
async def publish_registry_refresh_notification() -> None:
|
||||
"""
|
||||
Publish a notification to Redis that the LLM registry has been updated.
|
||||
All executor services subscribed to this channel will refresh their registry.
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
|
||||
logger.info("Published LLM registry refresh notification to Redis")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to publish LLM registry refresh notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
|
||||
on_refresh: Any, # Async callable that takes no args
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to Redis notifications for LLM registry updates.
|
||||
This runs in a loop and processes messages as they arrive.
|
||||
|
||||
Args:
|
||||
on_refresh: Async callable to execute when a refresh notification is received
|
||||
"""
|
||||
try:
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info(
|
||||
"Subscribed to LLM registry refresh notifications on channel: %s",
|
||||
REGISTRY_REFRESH_CHANNEL,
|
||||
)
|
||||
|
||||
# Process messages in a loop
|
||||
while True:
|
||||
try:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if (
|
||||
message
|
||||
and message["type"] == "message"
|
||||
and message["channel"] == REGISTRY_REFRESH_CHANNEL
|
||||
):
|
||||
logger.info("Received LLM registry refresh notification")
|
||||
try:
|
||||
await on_refresh()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Error refreshing LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Error processing registry refresh message: %s", exc, exc_info=True
|
||||
)
|
||||
# Continue listening even if one message fails
|
||||
await asyncio.sleep(1)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to subscribe to LLM registry refresh notifications: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
@@ -1,388 +0,0 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCreator:
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
website_url: str | None
|
||||
logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModel:
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any]
|
||||
extra_metadata: dict[str, Any]
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
_static_metadata: dict[str, ModelMetadata] = {}
|
||||
_static_costs: dict[str, int] = {}
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_discriminator_mapping: dict[str, str] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
|
||||
"""Register static metadata for legacy models (deprecated)."""
|
||||
_static_metadata.update({str(key): value for key, value in metadata.items()})
|
||||
_refresh_cached_schema()
|
||||
|
||||
|
||||
def register_static_costs(costs: dict[Any, int]) -> None:
|
||||
"""Register static costs for legacy models (deprecated)."""
|
||||
_static_costs.update({str(key): value for key, value in costs.items()})
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
options: list[dict[str, str]] = []
|
||||
# Only include enabled models in the dropdown options
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
options.append(
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
)
|
||||
|
||||
for slug, metadata in _static_metadata.items():
|
||||
if slug in _dynamic_models:
|
||||
continue
|
||||
options.append(
|
||||
{
|
||||
"label": slug,
|
||||
"value": slug,
|
||||
"group": metadata.provider,
|
||||
"description": "",
|
||||
}
|
||||
)
|
||||
return options
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.debug("Found %d LLM model records in database", len(records))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
dynamic: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display_name = (
|
||||
record.Provider.displayName if record.Provider else record.providerId
|
||||
)
|
||||
# Creator name: prefer Creator.name, fallback to provider display name
|
||||
creator_name = (
|
||||
record.Creator.name if record.Creator else provider_display_name
|
||||
)
|
||||
# Price tier: default to 1 (cheapest) if not set
|
||||
price_tier = getattr(record, "priceTier", 1) or 1
|
||||
# Clamp to valid range 1-3
|
||||
price_tier = max(1, min(3, price_tier))
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display_name,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier, # type: ignore[arg-type]
|
||||
)
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Map creator if present
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
dynamic[record.slug] = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=_json_to_dict(record.capabilities),
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
provider_display_name=(
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
),
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
|
||||
# Atomic swap - build new structures then replace references
|
||||
# This ensures readers never see partially updated state
|
||||
global _dynamic_models
|
||||
_dynamic_models = dynamic
|
||||
_refresh_cached_schema()
|
||||
logger.info(
|
||||
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
|
||||
len(dynamic),
|
||||
sum(1 for m in dynamic.values() if m.is_enabled),
|
||||
sum(1 for m in dynamic.values() if not m.is_enabled),
|
||||
)
|
||||
|
||||
|
||||
def _refresh_cached_schema() -> None:
|
||||
"""Refresh cached schema options and discriminator mapping."""
|
||||
global _schema_options, _discriminator_mapping
|
||||
|
||||
# Build new structures
|
||||
new_options = _build_schema_options()
|
||||
new_mapping = {
|
||||
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
|
||||
}
|
||||
for slug, metadata in _static_metadata.items():
|
||||
new_mapping.setdefault(slug, metadata.provider)
|
||||
|
||||
# Atomic swap - replace references to ensure readers see consistent state
|
||||
_schema_options = new_options
|
||||
_discriminator_mapping = new_mapping
|
||||
|
||||
|
||||
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
|
||||
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].metadata
|
||||
return _static_metadata.get(slug)
|
||||
|
||||
|
||||
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
|
||||
"""Get model cost configuration by slug."""
|
||||
if slug in _dynamic_models:
|
||||
return _dynamic_models[slug].costs
|
||||
cost_value = _static_costs.get(slug)
|
||||
if cost_value is None:
|
||||
return tuple()
|
||||
return (
|
||||
RegistryModelCost(
|
||||
credit_cost=cost_value,
|
||||
credential_provider="static",
|
||||
credential_id=None,
|
||||
credential_type=None,
|
||||
currency=None,
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_llm_model_schema_options() -> list[dict[str, str]]:
|
||||
"""
|
||||
Get schema options for LLM model selection dropdown.
|
||||
|
||||
Returns a copy of cached schema options that are refreshed when the registry is
|
||||
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return list(_schema_options)
|
||||
|
||||
|
||||
def get_llm_discriminator_mapping() -> dict[str, str]:
|
||||
"""
|
||||
Get discriminator mapping for LLM models.
|
||||
|
||||
Returns a copy of cached discriminator mapping that is refreshed when the registry
|
||||
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
|
||||
"""
|
||||
# Return a copy to prevent external mutation
|
||||
return dict(_discriminator_mapping)
|
||||
|
||||
|
||||
def get_dynamic_model_slugs() -> set[str]:
|
||||
"""Get all dynamic model slugs from the registry."""
|
||||
return set(_dynamic_models.keys())
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> set[str]:
|
||||
"""
|
||||
Get ALL model slugs (both enabled and disabled) for validation purposes.
|
||||
|
||||
This is used for JSON schema enum validation - we need to accept any known
|
||||
model value (even disabled ones) so that existing graphs don't fail validation.
|
||||
The actual fallback/enforcement happens at runtime in llm_call().
|
||||
"""
|
||||
all_slugs = set(_dynamic_models.keys())
|
||||
all_slugs.update(_static_metadata.keys())
|
||||
return all_slugs
|
||||
|
||||
|
||||
def iter_dynamic_models() -> Iterable[RegistryModel]:
|
||||
"""Iterate over all dynamic models in the registry."""
|
||||
return tuple(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
|
||||
"""
|
||||
Find a fallback model when the requested model is disabled.
|
||||
|
||||
Looks for an enabled model from the same provider. Prefers models with
|
||||
similar names or capabilities if possible.
|
||||
|
||||
Args:
|
||||
disabled_model_slug: The slug of the disabled model
|
||||
|
||||
Returns:
|
||||
An enabled RegistryModel from the same provider, or None if no fallback found
|
||||
"""
|
||||
disabled_model = _dynamic_models.get(disabled_model_slug)
|
||||
if not disabled_model:
|
||||
return None
|
||||
|
||||
provider = disabled_model.metadata.provider
|
||||
|
||||
# Find all enabled models from the same provider
|
||||
candidates = [
|
||||
model
|
||||
for model in _dynamic_models.values()
|
||||
if model.is_enabled and model.metadata.provider == provider
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by: prefer models with similar context window, then by name
|
||||
candidates.sort(
|
||||
key=lambda m: (
|
||||
abs(m.metadata.context_window - disabled_model.metadata.context_window),
|
||||
m.display_name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def is_model_enabled(model_slug: str) -> bool:
|
||||
"""Check if a model is enabled in the registry."""
|
||||
model = _dynamic_models.get(model_slug)
|
||||
if not model:
|
||||
# Model not in registry - assume it's a static/legacy model and allow it
|
||||
return True
|
||||
return model.is_enabled
|
||||
|
||||
|
||||
def get_model_info(model_slug: str) -> RegistryModel | None:
|
||||
"""Get model info from the registry."""
|
||||
return _dynamic_models.get(model_slug)
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""
|
||||
Get the default model slug to use for block defaults.
|
||||
|
||||
Returns the recommended model if set (configured via admin UI),
|
||||
otherwise returns the first enabled model alphabetically.
|
||||
Returns None if no models are available or enabled.
|
||||
"""
|
||||
# Return the recommended model if one is set and enabled
|
||||
for model in _dynamic_models.values():
|
||||
if model.is_recommended and model.is_enabled:
|
||||
return model.slug
|
||||
|
||||
# No recommended model set - find first enabled model alphabetically
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
logger.warning(
|
||||
"No recommended model set, using '%s' as default",
|
||||
model.slug,
|
||||
)
|
||||
return model.slug
|
||||
|
||||
# No enabled models available
|
||||
if _dynamic_models:
|
||||
logger.error(
|
||||
"No enabled models found in registry (%d models registered but all disabled)",
|
||||
len(_dynamic_models),
|
||||
)
|
||||
else:
|
||||
logger.error("No models registered in LLM registry")
|
||||
|
||||
return None
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Helper utilities for LLM registry integration with block schemas.
|
||||
|
||||
This module handles the dynamic injection of discriminator mappings
|
||||
and model options from the LLM registry into block schemas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.llm_registry.registry import (
|
||||
get_all_model_slugs_for_validation,
|
||||
get_default_model_slug,
|
||||
get_llm_discriminator_mapping,
|
||||
get_llm_model_schema_options,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
|
||||
"""
|
||||
Check if a field is an LLM model selection field.
|
||||
|
||||
Returns True if the field has 'options' in json_schema_extra
|
||||
(set by llm_model_schema_extra() in blocks/llm.py).
|
||||
"""
|
||||
if not hasattr(field_info, "json_schema_extra"):
|
||||
return False
|
||||
|
||||
extra = field_info.json_schema_extra
|
||||
if isinstance(extra, dict):
|
||||
return "options" in extra
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh LLM model options from the registry.
|
||||
|
||||
Updates 'options' (for frontend dropdown) to show only enabled models,
|
||||
but keeps the 'enum' (for validation) inclusive of ALL known models.
|
||||
|
||||
This is important because:
|
||||
- Options: What users see in the dropdown (enabled models only)
|
||||
- Enum: What values pass validation (all known models, including disabled)
|
||||
|
||||
Existing graphs may have disabled models selected - they should pass validation
|
||||
and the fallback logic in llm_call() will handle using an alternative model.
|
||||
"""
|
||||
fresh_options = get_llm_model_schema_options()
|
||||
if not fresh_options:
|
||||
return
|
||||
|
||||
# Update options array (UI dropdown) - only enabled models
|
||||
if "options" in field_schema:
|
||||
field_schema["options"] = fresh_options
|
||||
|
||||
all_known_slugs = get_all_model_slugs_for_validation()
|
||||
if all_known_slugs and "enum" in field_schema:
|
||||
existing_enum = set(field_schema.get("enum", []))
|
||||
combined_enum = existing_enum | all_known_slugs
|
||||
field_schema["enum"] = sorted(combined_enum)
|
||||
|
||||
# Set the default value from the registry (gpt-4o if available, else first enabled)
|
||||
# This ensures new blocks have a sensible default pre-selected
|
||||
default_slug = get_default_model_slug()
|
||||
if default_slug:
|
||||
field_schema["default"] = default_slug
|
||||
|
||||
|
||||
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh discriminator_mapping for fields that use model-based discrimination.
|
||||
|
||||
The discriminator is already set when AICredentialsField() creates the field.
|
||||
We only need to refresh the mapping when models are added/removed.
|
||||
"""
|
||||
if field_schema.get("discriminator") != "model":
|
||||
return
|
||||
|
||||
# Always refresh the mapping to get latest models
|
||||
fresh_mapping = get_llm_discriminator_mapping()
|
||||
if fresh_mapping is not None:
|
||||
field_schema["discriminator_mapping"] = fresh_mapping
|
||||
|
||||
|
||||
def update_schema_with_llm_registry(
|
||||
schema: dict[str, Any], model_class: type | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update a JSON schema with current LLM registry data.
|
||||
|
||||
Refreshes:
|
||||
1. Model options for LLM model selection fields (dropdown choices)
|
||||
2. Discriminator mappings for credentials fields (model → provider)
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to update (mutated in-place)
|
||||
model_class: The Pydantic model class (optional, for field introspection)
|
||||
"""
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
if not isinstance(field_schema, dict):
|
||||
continue
|
||||
|
||||
# Refresh model options for LLM model fields
|
||||
if model_class and hasattr(model_class, "model_fields"):
|
||||
field_info = model_class.model_fields.get(field_name)
|
||||
if field_info and is_llm_model_field(field_name, field_info):
|
||||
try:
|
||||
refresh_llm_model_options(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh LLM options for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Refresh discriminator mapping for fields that use model discrimination
|
||||
try:
|
||||
refresh_llm_discriminator_mapping(field_schema)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to refresh discriminator mapping for field %s: %s",
|
||||
field_name,
|
||||
exc,
|
||||
)
|
||||
@@ -40,7 +40,6 @@ from pydantic_core import (
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.llm_registry import update_schema_with_llm_registry
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.settings import Secrets
|
||||
@@ -545,9 +544,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
|
||||
# Ensure LLM discriminators are populated (delegates to shared helper)
|
||||
update_schema_with_llm_registry(schema, model_class)
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
@@ -696,20 +693,16 @@ def CredentialsField(
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
|
||||
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
|
||||
field_schema_extra: dict[str, Any] = {}
|
||||
|
||||
# Always include discriminator if provided
|
||||
if discriminator is not None:
|
||||
field_schema_extra["discriminator"] = discriminator
|
||||
# Always include discriminator_mapping when discriminator is set (even if empty initially)
|
||||
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
|
||||
|
||||
# Include other optional fields (only if not None)
|
||||
if required_scopes:
|
||||
field_schema_extra["credentials_scopes"] = list(required_scopes)
|
||||
if discriminator_values:
|
||||
field_schema_extra["discriminator_values"] = discriminator_values
|
||||
field_schema_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Merge any json_schema_extra passed in kwargs
|
||||
if "json_schema_extra" in kwargs:
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
Helper functions for LLM registry initialization in executor context.
|
||||
|
||||
These functions handle refreshing the LLM registry when the executor starts
|
||||
and subscribing to real-time updates via Redis pub/sub.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from backend.data import db, llm_registry
|
||||
from backend.data.block import BlockSchema, initialize_blocks
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.llm_registry import subscribe_to_registry_refresh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_registry_for_executor() -> None:
|
||||
"""
|
||||
Initialize blocks and refresh LLM registry in the executor context.
|
||||
|
||||
This must run in the executor's event loop to have access to the database.
|
||||
"""
|
||||
try:
|
||||
# Connect to database if not already connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
logger.info("[GraphExecutor] Connected to database for registry refresh")
|
||||
|
||||
# Initialize blocks (internally refreshes LLM registry and costs)
|
||||
await initialize_blocks()
|
||||
logger.info("[GraphExecutor] Blocks initialized")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_registry_on_notification() -> None:
|
||||
"""Refresh LLM registry when notified via Redis pub/sub."""
|
||||
try:
|
||||
# Ensure DB is connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
|
||||
# Refresh registry and costs
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they regenerate with new model options
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from notification")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def subscribe_to_registry_updates() -> None:
|
||||
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
|
||||
await subscribe_to_registry_refresh(refresh_registry_on_notification)
|
||||
@@ -702,20 +702,6 @@ class ExecutionProcessor:
|
||||
)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
|
||||
# Initialize LLM registry and subscribe to updates
|
||||
from backend.executor.llm_registry_init import (
|
||||
initialize_registry_for_executor,
|
||||
subscribe_to_registry_updates,
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
initialize_registry_for_executor(), self.node_execution_loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
subscribe_to_registry_updates(), self.node_execution_loop
|
||||
)
|
||||
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
|
||||
@@ -1,935 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Sequence, cast
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
def _json_dict(value: Any | None) -> dict[str, Any]:
|
||||
if not value:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {}
|
||||
|
||||
|
||||
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
|
||||
return llm_model.LlmModelCost(
|
||||
id=record.id,
|
||||
unit=record.unit,
|
||||
credit_cost=record.creditCost,
|
||||
credential_provider=record.credentialProvider,
|
||||
credential_id=record.credentialId,
|
||||
credential_type=record.credentialType,
|
||||
currency=record.currency,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(
|
||||
record: prisma.models.LlmModelCreator,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
return llm_model.LlmModelCreator(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
website_url=record.websiteUrl,
|
||||
logo_url=record.logoUrl,
|
||||
metadata=_json_dict(record.metadata),
|
||||
)
|
||||
|
||||
|
||||
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
|
||||
costs = []
|
||||
if record.Costs:
|
||||
costs = [_map_cost(cost) for cost in record.Costs]
|
||||
|
||||
creator = None
|
||||
if hasattr(record, "Creator") and record.Creator:
|
||||
creator = _map_creator(record.Creator)
|
||||
|
||||
return llm_model.LlmModel(
|
||||
id=record.id,
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
provider_id=record.providerId,
|
||||
creator_id=record.creatorId,
|
||||
creator=creator,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
capabilities=_json_dict(record.capabilities),
|
||||
metadata=_json_dict(record.metadata),
|
||||
costs=costs,
|
||||
)
|
||||
|
||||
|
||||
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
|
||||
models: list[llm_model.LlmModel] = []
|
||||
if record.Models:
|
||||
models = [_map_model(model) for model in record.Models]
|
||||
|
||||
return llm_model.LlmProvider(
|
||||
id=record.id,
|
||||
name=record.name,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
default_credential_provider=record.defaultCredentialProvider,
|
||||
default_credential_id=record.defaultCredentialId,
|
||||
default_credential_type=record.defaultCredentialType,
|
||||
supports_tools=record.supportsTools,
|
||||
supports_json_output=record.supportsJsonOutput,
|
||||
supports_reasoning=record.supportsReasoning,
|
||||
supports_parallel_tool=record.supportsParallelTool,
|
||||
metadata=_json_dict(record.metadata),
|
||||
models=models,
|
||||
)
|
||||
|
||||
|
||||
async def list_providers(
|
||||
include_models: bool = True, enabled_only: bool = False
|
||||
) -> list[llm_model.LlmProvider]:
|
||||
"""
|
||||
List all LLM providers.
|
||||
|
||||
Args:
|
||||
include_models: Whether to include models for each provider
|
||||
enabled_only: If True, only include enabled models (for public routes)
|
||||
"""
|
||||
include: Any = None
|
||||
if include_models:
|
||||
model_where = {"isEnabled": True} if enabled_only else None
|
||||
include = {
|
||||
"Models": {
|
||||
"include": {"Costs": True, "Creator": True},
|
||||
"where": model_where,
|
||||
}
|
||||
}
|
||||
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
|
||||
return [_map_provider(record) for record in records]
|
||||
|
||||
|
||||
async def upsert_provider(
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
provider_id: str | None = None,
|
||||
) -> llm_model.LlmProvider:
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"defaultCredentialProvider": request.default_credential_provider,
|
||||
"defaultCredentialId": request.default_credential_id,
|
||||
"defaultCredentialType": request.default_credential_type,
|
||||
"supportsTools": request.supports_tools,
|
||||
"supportsJsonOutput": request.supports_json_output,
|
||||
"supportsReasoning": request.supports_reasoning,
|
||||
"supportsParallelTool": request.supports_parallel_tool,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
|
||||
if provider_id:
|
||||
record = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmProvider.prisma().create(
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update provider")
|
||||
return _map_provider(record)
|
||||
|
||||
|
||||
async def delete_provider(provider_id: str) -> bool:
|
||||
"""
|
||||
Delete an LLM provider.
|
||||
|
||||
A provider can only be deleted if it has no associated models.
|
||||
Due to onDelete: Restrict on LlmModel.Provider, the database will
|
||||
block deletion if models exist.
|
||||
|
||||
Args:
|
||||
provider_id: UUID of the provider to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found or has associated models
|
||||
"""
|
||||
# Check if provider exists
|
||||
provider = await prisma.models.LlmProvider.prisma().find_unique(
|
||||
where={"id": provider_id},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider with id '{provider_id}' not found")
|
||||
|
||||
# Check if provider has any models
|
||||
model_count = len(provider.Models) if provider.Models else 0
|
||||
if model_count > 0:
|
||||
raise ValueError(
|
||||
f"Cannot delete provider '{provider.displayName}' because it has "
|
||||
f"{model_count} model(s). Delete all models first."
|
||||
)
|
||||
|
||||
# Safe to delete
|
||||
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
|
||||
return True
|
||||
|
||||
|
||||
async def list_models(
|
||||
provider_id: str | None = None,
|
||||
enabled_only: bool = False,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> llm_model.LlmModelsResponse:
|
||||
"""
|
||||
List LLM models with pagination.
|
||||
|
||||
Args:
|
||||
provider_id: Optional filter by provider ID
|
||||
enabled_only: If True, only return enabled models (for public routes)
|
||||
page: Page number (1-indexed)
|
||||
page_size: Number of models per page
|
||||
"""
|
||||
where: Any = {}
|
||||
if provider_id:
|
||||
where["providerId"] = provider_id
|
||||
if enabled_only:
|
||||
where["isEnabled"] = True
|
||||
|
||||
# Get total count for pagination
|
||||
total_items = await prisma.models.LlmModel.prisma().count(
|
||||
where=where if where else None
|
||||
)
|
||||
|
||||
# Calculate pagination
|
||||
skip = (page - 1) * page_size
|
||||
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
|
||||
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
where=where if where else None,
|
||||
include={"Costs": True, "Creator": True},
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
)
|
||||
models = [_map_model(record) for record in records]
|
||||
|
||||
return llm_model.LlmModelsResponse(
|
||||
models=models,
|
||||
pagination=Pagination(
|
||||
total_items=total_items,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _cost_create_payload(
|
||||
costs: Sequence[llm_model.LlmModelCostInput],
|
||||
) -> dict[str, Iterable[dict[str, Any]]]:
|
||||
|
||||
create_items = []
|
||||
for cost in costs:
|
||||
item: dict[str, Any] = {
|
||||
"unit": cost.unit,
|
||||
"creditCost": cost.credit_cost,
|
||||
"credentialProvider": cost.credential_provider,
|
||||
}
|
||||
# Only include optional fields if they have values
|
||||
if cost.credential_id:
|
||||
item["credentialId"] = cost.credential_id
|
||||
if cost.credential_type:
|
||||
item["credentialType"] = cost.credential_type
|
||||
if cost.currency:
|
||||
item["currency"] = cost.currency
|
||||
# Handle metadata - use Prisma Json type
|
||||
if cost.metadata is not None and cost.metadata != {}:
|
||||
item["metadata"] = prisma.Json(cost.metadata)
|
||||
create_items.append(item)
|
||||
return {"create": create_items}
|
||||
|
||||
|
||||
async def create_model(
|
||||
request: llm_model.CreateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
data: Any = {
|
||||
"slug": request.slug,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"Provider": {"connect": {"id": request.provider_id}},
|
||||
"contextWindow": request.context_window,
|
||||
"maxOutputTokens": request.max_output_tokens,
|
||||
"isEnabled": request.is_enabled,
|
||||
"capabilities": prisma.Json(request.capabilities or {}),
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
"Costs": _cost_create_payload(request.costs),
|
||||
}
|
||||
if request.creator_id:
|
||||
data["Creator"] = {"connect": {"id": request.creator_id}}
|
||||
|
||||
record = await prisma.models.LlmModel.prisma().create(
|
||||
data=data,
|
||||
include={"Costs": True, "Creator": True, "Provider": True},
|
||||
)
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def update_model(
|
||||
model_id: str,
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
# Build scalar field updates (non-relation fields)
|
||||
scalar_data: Any = {}
|
||||
if request.display_name is not None:
|
||||
scalar_data["displayName"] = request.display_name
|
||||
if request.description is not None:
|
||||
scalar_data["description"] = request.description
|
||||
if request.context_window is not None:
|
||||
scalar_data["contextWindow"] = request.context_window
|
||||
if request.max_output_tokens is not None:
|
||||
scalar_data["maxOutputTokens"] = request.max_output_tokens
|
||||
if request.is_enabled is not None:
|
||||
scalar_data["isEnabled"] = request.is_enabled
|
||||
if request.capabilities is not None:
|
||||
scalar_data["capabilities"] = request.capabilities
|
||||
if request.metadata is not None:
|
||||
scalar_data["metadata"] = request.metadata
|
||||
# Foreign keys can be updated directly as scalar fields
|
||||
if request.provider_id is not None:
|
||||
scalar_data["providerId"] = request.provider_id
|
||||
if request.creator_id is not None:
|
||||
# Empty string means remove the creator
|
||||
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
|
||||
|
||||
# If we have costs to update, we need to handle them separately
|
||||
# because nested writes have different constraints
|
||||
if request.costs is not None:
|
||||
# Wrap cost replacement in a transaction for atomicity
|
||||
async with transaction() as tx:
|
||||
# First update scalar fields
|
||||
if scalar_data:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
)
|
||||
# Then handle costs: delete existing and create new
|
||||
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
|
||||
if request.costs:
|
||||
cost_payload = _cost_create_payload(request.costs)
|
||||
for cost_item in cost_payload["create"]:
|
||||
cost_item["llmModelId"] = model_id
|
||||
await tx.llmmodelcost.create(data=cast(Any, cost_item))
|
||||
# Fetch the updated record (outside transaction)
|
||||
record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
else:
|
||||
# No costs update - simple update
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data=scalar_data,
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
|
||||
if not record:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return _map_model(record)
|
||||
|
||||
|
||||
async def toggle_model(
|
||||
model_id: str,
|
||||
is_enabled: bool,
|
||||
migrate_to_slug: str | None = None,
|
||||
migration_reason: str | None = None,
|
||||
custom_credit_cost: int | None = None,
|
||||
) -> llm_model.ToggleLlmModelResponse:
|
||||
"""
|
||||
Toggle a model's enabled status, optionally migrating workflows when disabling.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to toggle
|
||||
is_enabled: New enabled status
|
||||
migrate_to_slug: If disabling and this is provided, migrate all workflows
|
||||
using this model to the specified replacement model
|
||||
migration_reason: Optional reason for the migration (e.g., "Provider outage")
|
||||
custom_credit_cost: Optional custom pricing override for migrated workflows.
|
||||
When set, the billing system should use this cost instead
|
||||
of the target model's cost for affected nodes.
|
||||
|
||||
Returns:
|
||||
ToggleLlmModelResponse with the updated model and optional migration stats
|
||||
"""
|
||||
import json
|
||||
|
||||
# Get the model being toggled
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
nodes_migrated = 0
|
||||
migration_id: str | None = None
|
||||
|
||||
# If disabling with migration, perform migration first
|
||||
if not is_enabled and migrate_to_slug:
|
||||
# Validate replacement model exists and is enabled
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migrate_to_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{migrate_to_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# Perform all operations atomically within a single transaction
|
||||
# This ensures no nodes are missed between query and update
|
||||
async with transaction() as tx:
|
||||
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
|
||||
node_ids_result = await tx.query_raw(
|
||||
"""
|
||||
SELECT id
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
FOR UPDATE
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
migrated_node_ids = (
|
||||
[row["id"] for row in node_ids_result] if node_ids_result else []
|
||||
)
|
||||
nodes_migrated = len(migrated_node_ids)
|
||||
|
||||
if nodes_migrated > 0:
|
||||
# Update by IDs to ensure we only update the exact nodes we queried
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
""",
|
||||
migrate_to_slug,
|
||||
node_ids_json,
|
||||
)
|
||||
|
||||
record = await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
# Create migration record for revert capability
|
||||
if nodes_migrated > 0:
|
||||
migration_data: Any = {
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
migration_record = await tx.llmmodelmigration.create(
|
||||
data=migration_data
|
||||
)
|
||||
migration_id = migration_record.id
|
||||
else:
|
||||
# Simple toggle without migration
|
||||
record = await prisma.models.LlmModel.prisma().update(
|
||||
where={"id": model_id},
|
||||
data={"isEnabled": is_enabled},
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return llm_model.ToggleLlmModelResponse(
|
||||
model=_map_model(record),
|
||||
nodes_migrated=nodes_migrated,
|
||||
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
|
||||
migration_id=migration_id,
|
||||
)
|
||||
|
||||
|
||||
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
|
||||
"""Get usage count for a model."""
|
||||
import prisma as prisma_module
|
||||
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
model.slug,
|
||||
)
|
||||
node_count = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
|
||||
|
||||
|
||||
async def delete_model(
|
||||
model_id: str, replacement_model_slug: str | None = None
|
||||
) -> llm_model.DeleteLlmModelResponse:
|
||||
"""
|
||||
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
|
||||
|
||||
This performs an atomic operation within a database transaction:
|
||||
1. Validates the model exists
|
||||
2. Counts affected nodes
|
||||
3. If nodes exist, validates replacement model and migrates them
|
||||
4. Deletes the LlmModel record (CASCADE deletes costs)
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to delete
|
||||
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
|
||||
|
||||
Returns:
|
||||
DeleteLlmModelResponse with migration stats
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found, nodes exist but no replacement provided,
|
||||
replacement not found, or replacement is disabled
|
||||
"""
|
||||
# 1. Get the model being deleted (validation - outside transaction)
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
if not model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
|
||||
deleted_slug = model.slug
|
||||
deleted_display_name = model.displayName
|
||||
|
||||
# 2. Count affected nodes first to determine if replacement is needed
|
||||
import prisma as prisma_module
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM "AgentNode"
|
||||
WHERE "constantInput"::jsonb->>'model' = $1
|
||||
""",
|
||||
deleted_slug,
|
||||
)
|
||||
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
# 3. Validate replacement model only if there are nodes to migrate
|
||||
if nodes_to_migrate > 0:
|
||||
if not replacement_model_slug:
|
||||
raise ValueError(
|
||||
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
|
||||
f"are using it. Please provide a replacement_model_slug to migrate them."
|
||||
)
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": replacement_model_slug}
|
||||
)
|
||||
if not replacement:
|
||||
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
|
||||
if not replacement.isEnabled:
|
||||
raise ValueError(
|
||||
f"Replacement model '{replacement_model_slug}' is disabled. "
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# 4. Perform migration (if needed) and deletion atomically within a transaction
|
||||
async with transaction() as tx:
|
||||
# Migrate all AgentNode.constantInput->model to replacement
|
||||
if nodes_to_migrate > 0 and replacement_model_slug:
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
|
||||
# Delete the model (CASCADE will delete costs automatically)
|
||||
await tx.llmmodel.delete(where={"id": model_id})
|
||||
|
||||
# Build appropriate message based on whether migration happened
|
||||
if nodes_to_migrate > 0:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
|
||||
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
|
||||
f"No workflows were using this model."
|
||||
)
|
||||
|
||||
return llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug=deleted_slug,
|
||||
deleted_model_display_name=deleted_display_name,
|
||||
replacement_model_slug=replacement_model_slug,
|
||||
nodes_migrated=nodes_to_migrate,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
def _map_migration(
|
||||
record: prisma.models.LlmModelMigration,
|
||||
) -> llm_model.LlmModelMigration:
|
||||
return llm_model.LlmModelMigration(
|
||||
id=record.id,
|
||||
source_model_slug=record.sourceModelSlug,
|
||||
target_model_slug=record.targetModelSlug,
|
||||
reason=record.reason,
|
||||
node_count=record.nodeCount,
|
||||
custom_credit_cost=record.customCreditCost,
|
||||
is_reverted=record.isReverted,
|
||||
created_at=record.createdAt.isoformat(),
|
||||
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
|
||||
)
|
||||
|
||||
|
||||
async def list_migrations(
|
||||
include_reverted: bool = False,
|
||||
) -> list[llm_model.LlmModelMigration]:
|
||||
"""
|
||||
List model migrations, optionally including reverted ones.
|
||||
|
||||
Args:
|
||||
include_reverted: If True, include reverted migrations. Default is False.
|
||||
|
||||
Returns:
|
||||
List of LlmModelMigration records
|
||||
"""
|
||||
where: Any = None if include_reverted else {"isReverted": False}
|
||||
records = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||
where=where,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [_map_migration(record) for record in records]
|
||||
|
||||
|
||||
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
|
||||
"""Get a specific migration by ID."""
|
||||
record = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
return _map_migration(record) if record else None
|
||||
|
||||
|
||||
async def revert_migration(
|
||||
migration_id: str,
|
||||
re_enable_source_model: bool = True,
|
||||
) -> llm_model.RevertMigrationResponse:
|
||||
"""
|
||||
Revert a model migration, restoring affected nodes to their original model.
|
||||
|
||||
This only reverts the specific nodes that were migrated, not all nodes
|
||||
currently using the target model.
|
||||
|
||||
Args:
|
||||
migration_id: UUID of the migration to revert
|
||||
re_enable_source_model: Whether to re-enable the source model if it's disabled
|
||||
|
||||
Returns:
|
||||
RevertMigrationResponse with revert stats
|
||||
|
||||
Raises:
|
||||
ValueError: If migration not found, already reverted, or source model not available
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get the migration record
|
||||
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
|
||||
where={"id": migration_id}
|
||||
)
|
||||
if not migration:
|
||||
raise ValueError(f"Migration with id '{migration_id}' not found")
|
||||
|
||||
if migration.isReverted:
|
||||
raise ValueError(
|
||||
f"Migration '{migration_id}' has already been reverted "
|
||||
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
|
||||
)
|
||||
|
||||
# Check if source model exists
|
||||
source_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": migration.sourceModelSlug}
|
||||
)
|
||||
if not source_model:
|
||||
raise ValueError(
|
||||
f"Source model '{migration.sourceModelSlug}' no longer exists. "
|
||||
f"Cannot revert migration."
|
||||
)
|
||||
|
||||
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
|
||||
migrated_node_ids: list[str] = (
|
||||
migration.migratedNodeIds
|
||||
if isinstance(migration.migratedNodeIds, list)
|
||||
else json.loads(migration.migratedNodeIds) # type: ignore
|
||||
)
|
||||
if not migrated_node_ids:
|
||||
raise ValueError("No nodes to revert in this migration")
|
||||
|
||||
# Track if we need to re-enable the source model
|
||||
source_model_was_disabled = not source_model.isEnabled
|
||||
should_re_enable = source_model_was_disabled and re_enable_source_model
|
||||
source_model_re_enabled = False
|
||||
|
||||
# Perform revert atomically
|
||||
async with transaction() as tx:
|
||||
# Re-enable the source model if requested and it was disabled
|
||||
if should_re_enable:
|
||||
await tx.llmmodel.update(
|
||||
where={"id": source_model.id},
|
||||
data={"isEnabled": True},
|
||||
)
|
||||
source_model_re_enabled = True
|
||||
|
||||
# Update only the specific nodes that were migrated
|
||||
# We need to check that they still have the target model (haven't been changed since)
|
||||
# Use a single batch update for efficiency
|
||||
# Use JSON array and jsonb_array_elements_text for safe parameterization
|
||||
node_ids_json = json.dumps(migrated_node_ids)
|
||||
result = await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE id::text IN (
|
||||
SELECT jsonb_array_elements_text($2::jsonb)
|
||||
)
|
||||
AND "constantInput"::jsonb->>'model' = $3
|
||||
""",
|
||||
migration.sourceModelSlug,
|
||||
node_ids_json,
|
||||
migration.targetModelSlug,
|
||||
)
|
||||
nodes_reverted = result if result else 0
|
||||
|
||||
# Mark migration as reverted
|
||||
await tx.llmmodelmigration.update(
|
||||
where={"id": migration_id},
|
||||
data={
|
||||
"isReverted": True,
|
||||
"revertedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
# Calculate nodes that were already changed since migration
|
||||
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
|
||||
|
||||
# Build appropriate message
|
||||
message_parts = [
|
||||
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
|
||||
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
|
||||
]
|
||||
if nodes_already_changed > 0:
|
||||
message_parts.append(
|
||||
f" {nodes_already_changed} node(s) were already changed and not reverted."
|
||||
)
|
||||
if source_model_re_enabled:
|
||||
message_parts.append(
|
||||
f" Model '{migration.sourceModelSlug}' has been re-enabled."
|
||||
)
|
||||
|
||||
return llm_model.RevertMigrationResponse(
|
||||
migration_id=migration_id,
|
||||
source_model_slug=migration.sourceModelSlug,
|
||||
target_model_slug=migration.targetModelSlug,
|
||||
nodes_reverted=nodes_reverted,
|
||||
nodes_already_changed=nodes_already_changed,
|
||||
source_model_re_enabled=source_model_re_enabled,
|
||||
message="".join(message_parts),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creator CRUD operations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def list_creators() -> list[llm_model.LlmModelCreator]:
|
||||
"""List all LLM model creators."""
|
||||
records = await prisma.models.LlmModelCreator.prisma().find_many(
|
||||
order={"displayName": "asc"}
|
||||
)
|
||||
return [_map_creator(record) for record in records]
|
||||
|
||||
|
||||
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
|
||||
"""Get a specific creator by ID."""
|
||||
record = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
return _map_creator(record) if record else None
|
||||
|
||||
|
||||
async def upsert_creator(
|
||||
request: llm_model.UpsertLlmCreatorRequest,
|
||||
creator_id: str | None = None,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
"""Create or update a model creator."""
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
"websiteUrl": request.website_url,
|
||||
"logoUrl": request.logo_url,
|
||||
"metadata": prisma.Json(request.metadata or {}),
|
||||
}
|
||||
if creator_id:
|
||||
record = await prisma.models.LlmModelCreator.prisma().update(
|
||||
where={"id": creator_id},
|
||||
data=data,
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update creator")
|
||||
return _map_creator(record)
|
||||
|
||||
|
||||
async def delete_creator(creator_id: str) -> bool:
|
||||
"""
|
||||
Delete a model creator.
|
||||
|
||||
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
|
||||
|
||||
Args:
|
||||
creator_id: UUID of the creator to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ValueError: If creator not found
|
||||
"""
|
||||
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"id": creator_id}
|
||||
)
|
||||
if not creator:
|
||||
raise ValueError(f"Creator with id '{creator_id}' not found")
|
||||
|
||||
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
|
||||
return True
|
||||
|
||||
|
||||
async def get_recommended_model() -> llm_model.LlmModel | None:
|
||||
"""
|
||||
Get the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
return _map_model(record) if record else None
|
||||
|
||||
|
||||
async def set_recommended_model(
|
||||
model_id: str,
|
||||
) -> tuple[llm_model.LlmModel, str | None]:
|
||||
"""
|
||||
Set a model as the recommended model.
|
||||
|
||||
This will clear the isRecommended flag from any other model and set it
|
||||
on the specified model. The model must be enabled.
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to set as recommended
|
||||
|
||||
Returns:
|
||||
Tuple of (the updated model, previous recommended model slug or None)
|
||||
|
||||
Raises:
|
||||
ValueError: If model not found or not enabled
|
||||
"""
|
||||
# First, verify the model exists and is enabled
|
||||
target_model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}
|
||||
)
|
||||
if not target_model:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
if not target_model.isEnabled:
|
||||
raise ValueError(
|
||||
f"Cannot set disabled model '{target_model.slug}' as recommended"
|
||||
)
|
||||
|
||||
# Get the current recommended model (if any)
|
||||
current_recommended = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True}
|
||||
)
|
||||
previous_slug = current_recommended.slug if current_recommended else None
|
||||
|
||||
# Use a transaction to ensure atomicity
|
||||
async with transaction() as tx:
|
||||
# Clear isRecommended from all models
|
||||
await tx.llmmodel.update_many(
|
||||
where={"isRecommended": True},
|
||||
data={"isRecommended": False},
|
||||
)
|
||||
# Set the new recommended model
|
||||
await tx.llmmodel.update(
|
||||
where={"id": model_id},
|
||||
data={"isRecommended": True},
|
||||
)
|
||||
|
||||
# Fetch and return the updated model
|
||||
updated_record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
include={"Costs": True, "Creator": True},
|
||||
)
|
||||
if not updated_record:
|
||||
raise ValueError("Failed to fetch updated model")
|
||||
|
||||
return _map_model(updated_record), previous_slug
|
||||
|
||||
|
||||
async def get_recommended_model_slug() -> str | None:
|
||||
"""
|
||||
Get the slug of the currently recommended LLM model.
|
||||
|
||||
Returns:
|
||||
The slug of the recommended model, or None if no model is marked as recommended.
|
||||
"""
|
||||
record = await prisma.models.LlmModel.prisma().find_first(
|
||||
where={"isRecommended": True, "isEnabled": True},
|
||||
)
|
||||
return record.slug if record else None
|
||||
@@ -1,235 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
|
||||
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
|
||||
id: str
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = None
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
|
||||
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
|
||||
id: str
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
creator: Optional[LlmModelCreator] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
is_recommended: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = None
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
|
||||
providers: list[LlmProvider]
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
|
||||
models: list[LlmModel]
|
||||
pagination: Optional[Pagination] = None
|
||||
|
||||
|
||||
class LlmCreatorsResponse(pydantic.BaseModel):
|
||||
creators: list[LlmModelCreator]
|
||||
|
||||
|
||||
class UpsertLlmProviderRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
default_credential_provider: Optional[str] = None
|
||||
default_credential_id: Optional[str] = None
|
||||
default_credential_type: Optional[str] = "api_key"
|
||||
supports_tools: bool = True
|
||||
supports_json_output: bool = True
|
||||
supports_reasoning: bool = False
|
||||
supports_parallel_tool: bool = False
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class UpsertLlmCreatorRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website_url: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCostInput(pydantic.BaseModel):
|
||||
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_type: Optional[str] = "api_key"
|
||||
currency: Optional[str] = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
|
||||
slug: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
provider_id: str
|
||||
creator_id: Optional[str] = None
|
||||
context_window: int
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: bool = True
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCostInput]
|
||||
|
||||
@pydantic.field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str) -> str:
|
||||
if not v or len(v) > 100:
|
||||
raise ValueError("Slug must be 1-100 characters")
|
||||
if not SLUG_PATTERN.match(v):
|
||||
raise ValueError(
|
||||
"Slug must start with alphanumeric and contain only "
|
||||
"alphanumeric characters, dots, underscores, slashes, or hyphens"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
context_window: Optional[int] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
capabilities: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
provider_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
costs: Optional[list[LlmModelCostInput]] = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
|
||||
is_enabled: bool
|
||||
migrate_to_slug: Optional[str] = None
|
||||
migration_reason: Optional[str] = None # e.g., "Provider outage"
|
||||
# Custom pricing override for migrated workflows. When set, billing should use
|
||||
# this cost instead of the target model's cost for affected nodes.
|
||||
# See LlmModelMigration in schema.prisma for full documentation.
|
||||
custom_credit_cost: Optional[int] = None
|
||||
|
||||
|
||||
class ToggleLlmModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
nodes_migrated: int = 0
|
||||
migrated_to_slug: Optional[str] = None
|
||||
migration_id: Optional[str] = None # ID of the migration record for revert
|
||||
|
||||
|
||||
class DeleteLlmModelResponse(pydantic.BaseModel):
|
||||
deleted_model_slug: str
|
||||
deleted_model_display_name: str
|
||||
replacement_model_slug: Optional[str] = None
|
||||
nodes_migrated: int
|
||||
message: str
|
||||
|
||||
|
||||
class LlmModelUsageResponse(pydantic.BaseModel):
|
||||
model_slug: str
|
||||
node_count: int
|
||||
|
||||
|
||||
# Migration tracking models
|
||||
class LlmModelMigration(pydantic.BaseModel):
|
||||
id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
reason: Optional[str] = None
|
||||
node_count: int
|
||||
# Custom pricing override - billing should use this instead of target model's cost
|
||||
custom_credit_cost: Optional[int] = None
|
||||
is_reverted: bool = False
|
||||
created_at: datetime
|
||||
reverted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class LlmMigrationsResponse(pydantic.BaseModel):
|
||||
migrations: list[LlmModelMigration]
|
||||
|
||||
|
||||
class RevertMigrationRequest(pydantic.BaseModel):
|
||||
re_enable_source_model: bool = (
|
||||
True # Whether to re-enable the source model if disabled
|
||||
)
|
||||
|
||||
|
||||
class RevertMigrationResponse(pydantic.BaseModel):
|
||||
migration_id: str
|
||||
source_model_slug: str
|
||||
target_model_slug: str
|
||||
nodes_reverted: int
|
||||
nodes_already_changed: int = (
|
||||
0 # Nodes that were modified since migration (not reverted)
|
||||
)
|
||||
source_model_re_enabled: bool = False # Whether the source model was re-enabled
|
||||
message: str
|
||||
|
||||
|
||||
class SetRecommendedModelRequest(pydantic.BaseModel):
|
||||
model_id: str
|
||||
|
||||
|
||||
class SetRecommendedModelResponse(pydantic.BaseModel):
|
||||
model: LlmModel
|
||||
previous_recommended_slug: Optional[str] = None
|
||||
message: str
|
||||
|
||||
|
||||
class RecommendedModelResponse(pydantic.BaseModel):
|
||||
model: Optional[LlmModel] = None
|
||||
slug: Optional[str] = None
|
||||
@@ -1,29 +0,0 @@
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
|
||||
async def list_models(
|
||||
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = fastapi.Query(
|
||||
default=50, ge=1, le=100, description="Number of models per page"
|
||||
),
|
||||
):
|
||||
"""List all enabled LLM models available to users."""
|
||||
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
|
||||
async def list_providers():
|
||||
"""List all LLM providers with their enabled models."""
|
||||
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -350,6 +350,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Whether to mark failed scans as clean or not",
|
||||
)
|
||||
|
||||
agentgenerator_host: str = Field(
|
||||
default="",
|
||||
description="The host for the Agent Generator service (empty to use built-in)",
|
||||
)
|
||||
agentgenerator_port: int = Field(
|
||||
default=8000,
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=120,
|
||||
description="The timeout in seconds for Agent Generator service requests",
|
||||
)
|
||||
|
||||
enable_example_blocks: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable example blocks in production",
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
|
||||
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models (using CTEs to reference provider IDs)
|
||||
WITH provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
model_slug,
|
||||
model_display_name,
|
||||
NULL,
|
||||
p."id",
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM (VALUES
|
||||
-- OpenAI models
|
||||
('o3', 'O3', 'openai', 200000, 100000),
|
||||
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
|
||||
('o1', 'O1', 'openai', 200000, 100000),
|
||||
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
|
||||
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
|
||||
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
|
||||
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
|
||||
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
|
||||
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
|
||||
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
|
||||
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
|
||||
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
|
||||
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
|
||||
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
|
||||
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
|
||||
-- Anthropic models
|
||||
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
|
||||
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
|
||||
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
|
||||
-- AI/ML API models
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
|
||||
-- Groq models
|
||||
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
|
||||
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
|
||||
-- Ollama models
|
||||
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
|
||||
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
|
||||
('llama3', 'Llama 3', 'ollama', 8192, NULL),
|
||||
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
|
||||
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
|
||||
-- OpenRouter models
|
||||
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
|
||||
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
|
||||
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
|
||||
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
|
||||
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
|
||||
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
|
||||
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
|
||||
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
|
||||
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
|
||||
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
|
||||
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
|
||||
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
|
||||
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
|
||||
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
|
||||
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
|
||||
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
|
||||
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
|
||||
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
|
||||
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
|
||||
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
|
||||
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
|
||||
-- Llama API models
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
|
||||
-- v0 models
|
||||
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
|
||||
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
|
||||
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
|
||||
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
|
||||
JOIN provider_ids p ON p."name" = models.provider_name
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs (using CTEs to reference model IDs)
|
||||
WITH model_ids AS (
|
||||
SELECT "id", "slug", "providerId" FROM "LlmModel"
|
||||
),
|
||||
provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
cost,
|
||||
p."name",
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM (VALUES
|
||||
-- OpenAI costs
|
||||
('o3', 4),
|
||||
('o3-mini', 2),
|
||||
('o1', 16),
|
||||
('o1-mini', 4),
|
||||
('gpt-5-2025-08-07', 2),
|
||||
('gpt-5.1-2025-11-13', 5),
|
||||
('gpt-5-mini-2025-08-07', 1),
|
||||
('gpt-5-nano-2025-08-07', 1),
|
||||
('gpt-5-chat-latest', 5),
|
||||
('gpt-4.1-2025-04-14', 2),
|
||||
('gpt-4.1-mini-2025-04-14', 1),
|
||||
('gpt-4o-mini', 1),
|
||||
('gpt-4o', 3),
|
||||
('gpt-4-turbo', 10),
|
||||
('gpt-3.5-turbo', 1),
|
||||
-- Anthropic costs
|
||||
('claude-opus-4-1-20250805', 21),
|
||||
('claude-opus-4-20250514', 21),
|
||||
('claude-sonnet-4-20250514', 5),
|
||||
('claude-haiku-4-5-20251001', 4),
|
||||
('claude-opus-4-5-20251101', 14),
|
||||
('claude-sonnet-4-5-20250929', 9),
|
||||
('claude-3-7-sonnet-20250219', 5),
|
||||
('claude-3-haiku-20240307', 1),
|
||||
-- AI/ML API costs
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
|
||||
-- Groq costs
|
||||
('llama-3.3-70b-versatile', 1),
|
||||
('llama-3.1-8b-instant', 1),
|
||||
-- Ollama costs
|
||||
('llama3.3', 1),
|
||||
('llama3.2', 1),
|
||||
('llama3', 1),
|
||||
('llama3.1:405b', 1),
|
||||
('dolphin-mistral:latest', 1),
|
||||
-- OpenRouter costs
|
||||
('google/gemini-2.5-pro-preview-03-25', 4),
|
||||
('google/gemini-3-pro-preview', 5),
|
||||
('mistralai/mistral-nemo', 1),
|
||||
('cohere/command-r-08-2024', 1),
|
||||
('cohere/command-r-plus-08-2024', 3),
|
||||
('deepseek/deepseek-chat', 2),
|
||||
('perplexity/sonar', 1),
|
||||
('perplexity/sonar-pro', 5),
|
||||
('perplexity/sonar-deep-research', 10),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 1),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 1),
|
||||
('amazon/nova-lite-v1', 1),
|
||||
('amazon/nova-micro-v1', 1),
|
||||
('amazon/nova-pro-v1', 1),
|
||||
('microsoft/wizardlm-2-8x22b', 1),
|
||||
('gryphe/mythomax-l2-13b', 1),
|
||||
('meta-llama/llama-4-scout', 1),
|
||||
('meta-llama/llama-4-maverick', 1),
|
||||
('x-ai/grok-4', 9),
|
||||
('x-ai/grok-4-fast', 1),
|
||||
('x-ai/grok-4.1-fast', 1),
|
||||
('x-ai/grok-code-fast-1', 1),
|
||||
('moonshotai/kimi-k2', 1),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 1),
|
||||
('qwen/qwen3-coder', 9),
|
||||
('google/gemini-2.5-flash', 1),
|
||||
('google/gemini-2.0-flash-001', 1),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 1),
|
||||
('google/gemini-2.0-flash-lite-001', 1),
|
||||
('deepseek/deepseek-r1-0528', 1),
|
||||
('openai/gpt-oss-120b', 1),
|
||||
('openai/gpt-oss-20b', 1),
|
||||
-- Llama API costs
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
|
||||
('Llama-3.3-8B-Instruct', 1),
|
||||
('Llama-3.3-70B-Instruct', 1),
|
||||
-- v0 costs
|
||||
('v0-1.5-md', 1),
|
||||
('v0-1.5-lg', 2),
|
||||
('v0-1.0-md', 1)
|
||||
) AS costs(model_slug, cost)
|
||||
JOIN model_ids m ON m."slug" = costs.model_slug
|
||||
JOIN provider_ids p ON p."id" = m."providerId"
|
||||
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelMigration" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"sourceModelSlug" TEXT NOT NULL,
|
||||
"targetModelSlug" TEXT NOT NULL,
|
||||
"reason" TEXT,
|
||||
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
|
||||
"nodeCount" INTEGER NOT NULL,
|
||||
"customCreditCost" INTEGER,
|
||||
"isReverted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"revertedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");
|
||||
@@ -1,127 +0,0 @@
|
||||
-- Add LlmModelCreator table
|
||||
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
|
||||
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
|
||||
|
||||
-- Create the LlmModelCreator table
|
||||
CREATE TABLE "LlmModelCreator" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"websiteUrl" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- Create unique index on name
|
||||
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
|
||||
|
||||
-- Add creatorId column to LlmModel
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
|
||||
|
||||
-- Add foreign key constraint
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- Create index on creatorId
|
||||
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Seed creators based on known model creators
|
||||
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
|
||||
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Update existing models with their creators
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
|
||||
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
|
||||
WHERE "slug" LIKE 'claude-%';
|
||||
|
||||
-- Meta/Llama models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
|
||||
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
|
||||
|
||||
-- Google models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
|
||||
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
|
||||
|
||||
-- Mistral models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
|
||||
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
|
||||
|
||||
-- Cohere models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
|
||||
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
|
||||
|
||||
-- DeepSeek models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
|
||||
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
|
||||
|
||||
-- Perplexity models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
|
||||
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
|
||||
|
||||
-- Qwen models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
|
||||
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
|
||||
|
||||
-- xAI/Grok models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
|
||||
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
|
||||
|
||||
-- Amazon models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
|
||||
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
|
||||
|
||||
-- Microsoft models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
|
||||
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
|
||||
|
||||
-- Moonshot models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
|
||||
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
|
||||
|
||||
-- NVIDIA models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
|
||||
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
|
||||
|
||||
-- Nous Research models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
|
||||
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
|
||||
|
||||
-- Vercel/v0 models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
|
||||
WHERE "slug" LIKE 'v0-%';
|
||||
|
||||
-- Dolphin models (Cognitive Computations / Eric Hartford)
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
|
||||
WHERE "slug" LIKE 'dolphin-%';
|
||||
|
||||
-- Gryphe models
|
||||
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
|
||||
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';
|
||||
@@ -1,4 +0,0 @@
|
||||
-- CreateIndex
|
||||
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
|
||||
-- This improves performance of model migration queries in the LLM registry
|
||||
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));
|
||||
@@ -1,52 +0,0 @@
|
||||
-- Add GPT-5.2 model and update O3 slug
|
||||
-- This migration adds the new GPT-5.2 model added in dev branch
|
||||
|
||||
-- Update O3 slug to match dev branch format
|
||||
UPDATE "LlmModel"
|
||||
SET "slug" = 'o3-2025-04-16'
|
||||
WHERE "slug" = 'o3';
|
||||
|
||||
-- Update cost reference for O3 if needed
|
||||
-- (costs are linked by model ID, so no update needed)
|
||||
|
||||
-- Add GPT-5.2 model
|
||||
WITH provider_id AS (
|
||||
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'gpt-5.2-2025-12-11',
|
||||
'GPT 5.2',
|
||||
'OpenAI GPT-5.2 model',
|
||||
p."id",
|
||||
400000,
|
||||
128000,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM provider_id p
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Add cost for GPT-5.2
|
||||
WITH model_id AS (
|
||||
SELECT m."id", p."name" as provider_name
|
||||
FROM "LlmModel" m
|
||||
JOIN "LlmProvider" p ON p."id" = m."providerId"
|
||||
WHERE m."slug" = 'gpt-5.2-2025-12-11'
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
3, -- Same cost tier as GPT-5.1
|
||||
m.provider_name,
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM model_id m
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
|
||||
);
|
||||
@@ -1,11 +0,0 @@
|
||||
-- Add isRecommended field to LlmModel table
|
||||
-- This allows admins to mark a model as the recommended default
|
||||
|
||||
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
-- Set gpt-4o-mini as the default recommended model (if it exists)
|
||||
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
|
||||
|
||||
-- Create unique partial index to enforce only one recommended model at the database level
|
||||
-- This prevents multiple rows from having isRecommended = true
|
||||
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;
|
||||
@@ -1,37 +1,12 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
|
||||
-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
|
||||
-- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
|
||||
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
||||
DO $$
|
||||
DECLARE
|
||||
current_schema_name text;
|
||||
vector_schema text;
|
||||
BEGIN
|
||||
-- Get the current schema from search_path
|
||||
SELECT current_schema() INTO current_schema_name;
|
||||
|
||||
-- Check if vector extension exists and which schema it's in
|
||||
SELECT n.nspname INTO vector_schema
|
||||
FROM pg_extension e
|
||||
JOIN pg_namespace n ON e.extnamespace = n.oid
|
||||
WHERE e.extname = 'vector';
|
||||
|
||||
-- Handle removal if in wrong schema
|
||||
IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
|
||||
BEGIN
|
||||
-- Vector exists in a different schema, drop it first
|
||||
RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
|
||||
vector_schema, current_schema_name;
|
||||
EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
|
||||
vector_schema, SQLERRM;
|
||||
END;
|
||||
END IF;
|
||||
|
||||
-- Create extension in current schema (let it fail naturally if not available)
|
||||
EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
|
||||
CREATE EXTENSION IF NOT EXISTS "vector";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Return to platform
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -1,61 +0,0 @@
|
||||
-- Add new columns to LlmModel table for extended model metadata
|
||||
-- These columns support the LLM Picker UI enhancements
|
||||
|
||||
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
|
||||
|
||||
-- Add creatorId column for model creator relationship (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
|
||||
|
||||
-- Add isRecommended column (if not exists)
|
||||
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add index on creatorId if not exists
|
||||
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- Add foreign key for creatorId if not exists
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
|
||||
-- Only add FK if LlmModelCreator table exists
|
||||
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
|
||||
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
END IF;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Update priceTier values for existing models based on original MODEL_METADATA
|
||||
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
|
||||
|
||||
-- OpenAI models
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
|
||||
|
||||
-- Anthropic models
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
|
||||
|
||||
-- OpenRouter models - Pro/expensive tiers
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
|
||||
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';
|
||||
@@ -1096,153 +1096,6 @@ enum APIKeyStatus {
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
///////////// LLM REGISTRY AND BILLING DATA /////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
|
||||
// This is distinct from BlockCostType (in backend/data/block.py) which defines
|
||||
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
|
||||
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
|
||||
// while BlockCostType is for billing platform block executions.
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
supportsTools Boolean @default(true)
|
||||
supportsJsonOutput Boolean @default(true)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelTool Boolean @default(false)
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
@@index([slug])
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int
|
||||
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([llmModelId, credentialProvider, unit])
|
||||
@@index([llmModelId])
|
||||
@@index([credentialProvider])
|
||||
}
|
||||
|
||||
// Tracks model migrations for revert capability
|
||||
// When a model is disabled with migration, we record which nodes were affected
|
||||
// so they can be reverted when the original model is back online
|
||||
model LlmModelMigration {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
sourceModelSlug String // The original model that was disabled
|
||||
targetModelSlug String // The model workflows were migrated to
|
||||
reason String? // Why the migration happened (e.g., "Provider outage")
|
||||
|
||||
// Track affected nodes as JSON array of node IDs
|
||||
// Format: ["node-uuid-1", "node-uuid-2", ...]
|
||||
migratedNodeIds Json @default("[]")
|
||||
nodeCount Int // Number of nodes migrated
|
||||
|
||||
// Custom pricing override for migrated workflows during the migration period.
|
||||
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
|
||||
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
|
||||
// to avoid billing surprises, or offer a discount during the transition.
|
||||
//
|
||||
// IMPORTANT: This field is intended for integration with the billing system.
|
||||
// When billing calculates costs for nodes affected by this migration, it should
|
||||
// check if customCreditCost is set and use it instead of the target model's cost.
|
||||
// If null, the target model's normal cost applies.
|
||||
//
|
||||
// TODO: Integrate with billing system to apply this override during cost calculation.
|
||||
customCreditCost Int?
|
||||
|
||||
// Revert tracking
|
||||
isReverted Boolean @default(false)
|
||||
revertedAt DateTime?
|
||||
|
||||
@@index([sourceModelSlug])
|
||||
@@index([targetModelSlug])
|
||||
@@index([isReverted])
|
||||
}
|
||||
////////////// OAUTH PROVIDER TABLES //////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for agent generator module."""
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Tests for the Agent Generator core module.
|
||||
|
||||
This test suite verifies that the core functions correctly delegate to
|
||||
the external Agent Generator service.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
from backend.api.features.chat.tools.agent_generator.core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
)
|
||||
|
||||
|
||||
class TestServiceNotConfigured:
|
||||
"""Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_raises_when_not_configured(self):
|
||||
"""Test that decompose_goal raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.decompose_goal("Build a chatbot")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_raises_when_not_configured(self):
|
||||
"""Test that generate_agent raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.generate_agent({"steps": []})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_raises_when_not_configured(self):
|
||||
"""Test that generate_agent_patch raises error when service not configured."""
|
||||
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||
await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
|
||||
class TestDecomposeGoal:
|
||||
"""Test decompose_goal function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that decompose_goal calls the external service."""
|
||||
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "")
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_context_to_external_service(self):
|
||||
"""Test that decompose_goal passes context to external service."""
|
||||
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_service_failure(self):
|
||||
"""Test that decompose_goal returns None when external service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "decompose_goal_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgent:
|
||||
"""Test generate_agent function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that generate_agent calls the external service."""
|
||||
expected_result = {"name": "Test Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions)
|
||||
# Result should have id, version, is_active added if not present
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
assert "id" in result
|
||||
assert result["version"] == 1
|
||||
assert result["is_active"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_existing_id_and_version(self):
|
||||
"""Test that external service result preserves existing id and version."""
|
||||
expected_result = {
|
||||
"id": "existing-id",
|
||||
"version": 3,
|
||||
"is_active": False,
|
||||
"name": "Test Agent",
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result.copy()
|
||||
|
||||
result = await core.generate_agent({"steps": []})
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == "existing-id"
|
||||
assert result["version"] == 3
|
||||
assert result["is_active"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_external_service_fails(self):
|
||||
"""Test that generate_agent returns None when external service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.generate_agent({"steps": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentPatch:
|
||||
"""Test generate_agent_patch function service delegation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_external_service(self):
|
||||
"""Test that generate_agent_patch calls the external service."""
|
||||
expected_result = {"name": "Updated Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with("Add a node", current_agent)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_clarifying_questions(self):
|
||||
"""Test that generate_agent_patch returns clarifying questions."""
|
||||
expected_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [{"question": "What type of node?"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = expected_result
|
||||
|
||||
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_external_service_fails(self):
|
||||
"""Test that generate_agent_patch returns None when service fails."""
|
||||
with patch.object(
|
||||
core, "is_external_service_configured", return_value=True
|
||||
), patch.object(
|
||||
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||
) as mock_external:
|
||||
mock_external.return_value = None
|
||||
|
||||
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestJsonToGraph:
|
||||
"""Test json_to_graph function."""
|
||||
|
||||
def test_converts_agent_json_to_graph(self):
|
||||
"""Test conversion of agent JSON to Graph model."""
|
||||
agent_json = {
|
||||
"id": "test-id",
|
||||
"version": 2,
|
||||
"is_active": True,
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"block_id": "block1",
|
||||
"input_default": {"key": "value"},
|
||||
"metadata": {"x": 100},
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "link1",
|
||||
"source_id": "node1",
|
||||
"sink_id": "output",
|
||||
"source_name": "result",
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
graph = core.json_to_graph(agent_json)
|
||||
|
||||
assert graph.id == "test-id"
|
||||
assert graph.version == 2
|
||||
assert graph.is_active is True
|
||||
assert graph.name == "Test Agent"
|
||||
assert graph.description == "A test agent"
|
||||
assert len(graph.nodes) == 1
|
||||
assert graph.nodes[0].id == "node1"
|
||||
assert graph.nodes[0].block_id == "block1"
|
||||
assert len(graph.links) == 1
|
||||
assert graph.links[0].source_id == "node1"
|
||||
|
||||
def test_generates_ids_if_missing(self):
|
||||
"""Test that missing IDs are generated."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"nodes": [{"block_id": "block1"}],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
graph = core.json_to_graph(agent_json)
|
||||
|
||||
assert graph.id is not None
|
||||
assert graph.nodes[0].id is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Tests for the Agent Generator external service client.
|
||||
|
||||
This test suite verifies the external Agent Generator service integration,
|
||||
including service detection, API calls, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import service
|
||||
|
||||
|
||||
class TestServiceConfiguration:
|
||||
"""Test service configuration detection."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset settings singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
def test_external_service_not_configured_when_host_empty(self):
|
||||
"""Test that external service is not configured when host is empty."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = ""
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
assert service.is_external_service_configured() is False
|
||||
|
||||
def test_external_service_configured_when_host_set(self):
|
||||
"""Test that external service is configured when host is set."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
assert service.is_external_service_configured() is True
|
||||
|
||||
def test_get_base_url(self):
|
||||
"""Test base URL construction."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||
mock_settings.config.agentgenerator_port = 8000
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
url = service._get_base_url()
|
||||
assert url == "http://agent-generator.local:8000"
|
||||
|
||||
|
||||
class TestDecomposeGoalExternal:
|
||||
"""Test decompose_goal_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_instructions(self):
|
||||
"""Test successful decomposition returning instructions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description", json={"description": "Build a chatbot"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_clarifying_questions(self):
|
||||
"""Test decomposition returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build something")
|
||||
|
||||
assert result == {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_with_context(self):
|
||||
"""Test decomposition with additional context."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external(
|
||||
"Build a chatbot", context="Use Python"
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_unachievable_goal(self):
|
||||
"""Test decomposition returning unachievable goal response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Do something impossible")
|
||||
|
||||
assert result == {
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_http_error(self):
|
||||
"""Test decomposition handles HTTP errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=MagicMock()
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_request_error(self):
|
||||
"""Test decomposition handles request errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_service_error(self):
|
||||
"""Test decomposition handles service returning error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": False,
|
||||
"error": "Internal error",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentExternal:
|
||||
"""Test generate_agent_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_success(self):
|
||||
"""Test successful agent generation."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"nodes": [],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": agent_json,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_external(instructions)
|
||||
|
||||
assert result == agent_json
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_handles_error(self):
|
||||
"""Test agent generation handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_external({"steps": []})
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateAgentPatchExternal:
|
||||
"""Test generate_agent_patch_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_updated_agent(self):
|
||||
"""Test successful patch generation returning updated agent."""
|
||||
updated_agent = {
|
||||
"name": "Updated Agent",
|
||||
"nodes": [{"id": "1", "block_id": "test"}],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": updated_agent,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add a new node", current_agent
|
||||
)
|
||||
|
||||
assert result == updated_agent
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": "Add a new node",
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_clarifying_questions(self):
|
||||
"""Test patch generation returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add something", {"nodes": []}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Test health_check function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset singletons before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_when_not_configured(self):
|
||||
"""Test health check returns False when service not configured."""
|
||||
with patch.object(
|
||||
service, "is_external_service_configured", return_value=False
|
||||
):
|
||||
result = await service.health_check()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_true_when_healthy(self):
|
||||
"""Test health check returns True when service is healthy."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"status": "healthy",
|
||||
"blocks_loaded": True,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_client.get.assert_called_once_with("/health")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_when_not_healthy(self):
|
||||
"""Test health check returns False when service is not healthy."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"status": "unhealthy",
|
||||
"blocks_loaded": False,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_false_on_error(self):
|
||||
"""Test health check returns False on connection error."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBlocksExternal:
|
||||
"""Test get_blocks_external function."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks_success(self):
|
||||
"""Test successful blocks retrieval."""
|
||||
blocks = [
|
||||
{"id": "block1", "name": "Block 1"},
|
||||
{"id": "block2", "name": "Block 2"},
|
||||
]
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"blocks": blocks,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result == blocks
|
||||
mock_client.get.assert_called_once_with("/api/blocks")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks_handles_error(self):
|
||||
"""Test blocks retrieval handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -29,4 +29,4 @@ NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
|
||||
# PR previews
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
@@ -175,8 +175,6 @@ While server components and actions are cool and cutting-edge, they introduce a
|
||||
|
||||
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
||||
- Co-locate UI state inside components/hooks; keep global state minimal
|
||||
- Avoid `useMemo` and `useCallback` unless you have a measured performance issue
|
||||
- Do not abuse `useEffect`; prefer state colocation and derive values directly when possible
|
||||
|
||||
### Styling and components
|
||||
|
||||
@@ -551,48 +549,9 @@ Files:
|
||||
Types:
|
||||
|
||||
- Prefer `interface` for object shapes
|
||||
- Component props should be `interface Props { ... }` (not exported)
|
||||
- Only use specific exported names (e.g., `export interface MyComponentProps`) when the interface needs to be used outside the component
|
||||
- Keep type definitions inline with the component - do not create separate `types.ts` files unless types are shared across multiple files
|
||||
- Component props should be `interface Props { ... }`
|
||||
- Use precise types; avoid `any` and unsafe casts
|
||||
|
||||
**Props naming examples:**
|
||||
|
||||
```tsx
|
||||
// ✅ Good - internal props, not exported
|
||||
interface Props {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: Props) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ✅ Good - exported when needed externally
|
||||
export interface ModalProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: ModalProps) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ❌ Bad - unnecessarily specific name for internal use
|
||||
interface ModalComponentProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
// ❌ Bad - separate types.ts file for single component
|
||||
// types.ts
|
||||
export interface ModalProps { ... }
|
||||
|
||||
// Modal.tsx
|
||||
import type { ModalProps } from './types';
|
||||
```
|
||||
|
||||
Parameters:
|
||||
|
||||
- If more than one parameter is needed, pass a single `Args` object for clarity
|
||||
|
||||
@@ -16,12 +16,6 @@ export default defineConfig({
|
||||
client: "react-query",
|
||||
httpClient: "fetch",
|
||||
indexFiles: false,
|
||||
mock: {
|
||||
type: "msw",
|
||||
baseUrl: "http://localhost:3000/api/proxy",
|
||||
generateEachHttpStatus: true,
|
||||
delay: 0,
|
||||
},
|
||||
override: {
|
||||
mutator: {
|
||||
path: "./mutators/custom-mutator.ts",
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
"types": "tsc --noEmit",
|
||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
||||
"test:unit": "vitest run",
|
||||
"test:unit:watch": "vitest",
|
||||
"test:no-build": "playwright test",
|
||||
"gentests": "playwright codegen http://localhost:3000",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
@@ -120,7 +118,6 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.2",
|
||||
"happy-dom": "20.3.4",
|
||||
"@opentelemetry/instrumentation": "0.209.0",
|
||||
"@playwright/test": "1.56.1",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
@@ -130,8 +127,6 @@
|
||||
"@storybook/nextjs": "9.1.5",
|
||||
"@tanstack/eslint-plugin-query": "5.91.2",
|
||||
"@tanstack/react-query-devtools": "5.90.2",
|
||||
"@testing-library/dom": "10.4.1",
|
||||
"@testing-library/react": "16.3.2",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/negotiator": "0.6.4",
|
||||
@@ -140,7 +135,6 @@
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "1.8.8",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
@@ -159,9 +153,7 @@
|
||||
"require-in-the-middle": "8.0.1",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.3",
|
||||
"vite-tsconfig-paths": "6.0.4",
|
||||
"vitest": "4.0.17"
|
||||
"typescript": "5.9.3"
|
||||
},
|
||||
"msw": {
|
||||
"workerDirectory": [
|
||||
|
||||
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useRef } from "react";
|
||||
|
||||
const LOGOUT_REDIRECT_DELAY_MS = 400;
|
||||
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise(function resolveAfterDelay(resolve) {
|
||||
setTimeout(resolve, ms);
|
||||
});
|
||||
}
|
||||
|
||||
export default function LogoutPage() {
|
||||
const { logOut } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const router = useRouter();
|
||||
const hasStartedRef = useRef(false);
|
||||
|
||||
useEffect(
|
||||
function handleLogoutEffect() {
|
||||
if (hasStartedRef.current) return;
|
||||
hasStartedRef.current = true;
|
||||
|
||||
async function runLogout() {
|
||||
try {
|
||||
await logOut();
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to log out. Redirecting to login.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} finally {
|
||||
await wait(LOGOUT_REDIRECT_DELAY_MS);
|
||||
router.replace("/login");
|
||||
}
|
||||
}
|
||||
|
||||
void runLogout();
|
||||
},
|
||||
[logOut, router, toast],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center px-4">
|
||||
<div className="flex flex-col items-center justify-center gap-4 py-8">
|
||||
<LoadingSpinner size="large" />
|
||||
<Text variant="body" className="text-center">
|
||||
Logging you out...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,8 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -29,11 +26,6 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "LLM Registry",
|
||||
href: "/admin/llms",
|
||||
icon: <Cpu size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
// Generated API functions
|
||||
import {
|
||||
getV2ListLlmProviders,
|
||||
postV2CreateLlmProvider,
|
||||
patchV2UpdateLlmProvider,
|
||||
deleteV2DeleteLlmProvider,
|
||||
getV2ListLlmModels,
|
||||
postV2CreateLlmModel,
|
||||
patchV2UpdateLlmModel,
|
||||
patchV2ToggleLlmModelAvailability,
|
||||
deleteV2DeleteLlmModelAndMigrateWorkflows,
|
||||
getV2GetModelUsageCount,
|
||||
getV2ListModelMigrations,
|
||||
postV2RevertAModelMigration,
|
||||
getV2ListModelCreators,
|
||||
postV2CreateModelCreator,
|
||||
patchV2UpdateModelCreator,
|
||||
deleteV2DeleteModelCreator,
|
||||
postV2SetRecommendedModel,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
|
||||
// Generated types
|
||||
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
|
||||
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
|
||||
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
|
||||
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
|
||||
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
|
||||
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
|
||||
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
|
||||
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
|
||||
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
|
||||
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
|
||||
import { LlmCostUnit } from "@/app/api/__generated__/models/llmCostUnit";
|
||||
|
||||
const ADMIN_LLM_PATH = "/admin/llms";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Extracts and validates a required string field from FormData.
|
||||
* Throws an error if the field is missing or empty.
|
||||
*/
|
||||
function getRequiredFormField(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): string {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = raw ? String(raw).trim() : "";
|
||||
if (!value) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required positive number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a positive number.
|
||||
*/
|
||||
function getRequiredPositiveNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
|
||||
throw new Error(`${displayName || fieldName} must be a positive number`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and validates a required number field from FormData.
|
||||
* Throws an error if the field is missing, empty, or not a finite number.
|
||||
*/
|
||||
function getRequiredNumber(
|
||||
formData: FormData,
|
||||
fieldName: string,
|
||||
displayName?: string,
|
||||
): number {
|
||||
const raw = formData.get(fieldName);
|
||||
const value = Number(raw);
|
||||
if (raw === null || raw === "" || !Number.isFinite(value)) {
|
||||
throw new Error(`${displayName || fieldName} is required`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Provider Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
|
||||
const response = await getV2ListLlmProviders({ include_models: true });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM providers");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmProviderAction(formData: FormData) {
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmProvider(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmProviderAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const response = await deleteV2DeleteLlmProvider(providerId);
|
||||
if (response.status !== 200) {
|
||||
const errorData = response.data as { detail?: string };
|
||||
throw new Error(errorData?.detail || "Failed to delete provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmProviderAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(
|
||||
formData,
|
||||
"provider_id",
|
||||
"Provider id",
|
||||
);
|
||||
|
||||
const payload: UpsertLlmProviderRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
default_credential_provider: formData.get("default_credential_provider")
|
||||
? String(formData.get("default_credential_provider")).trim()
|
||||
: undefined,
|
||||
default_credential_id: formData.get("default_credential_id")
|
||||
? String(formData.get("default_credential_id")).trim()
|
||||
: undefined,
|
||||
default_credential_type: formData.get("default_credential_type")
|
||||
? String(formData.get("default_credential_type")).trim()
|
||||
: "api_key",
|
||||
supports_tools: formData.getAll("supports_tools").includes("on"),
|
||||
supports_json_output: formData
|
||||
.getAll("supports_json_output")
|
||||
.includes("on"),
|
||||
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
|
||||
supports_parallel_tool: formData
|
||||
.getAll("supports_parallel_tool")
|
||||
.includes("on"),
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmProvider(providerId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM provider");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
|
||||
const response = await getV2ListLlmModels();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch LLM models");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmModelAction(formData: FormData) {
|
||||
const providerId = getRequiredFormField(formData, "provider_id", "Provider");
|
||||
const creatorId = formData.get("creator_id");
|
||||
const contextWindow = getRequiredPositiveNumber(
|
||||
formData,
|
||||
"context_window",
|
||||
"Context window",
|
||||
);
|
||||
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
|
||||
|
||||
// Fetch provider to get default credentials
|
||||
const providersResponse = await getV2ListLlmProviders({
|
||||
include_models: false,
|
||||
});
|
||||
if (providersResponse.status !== 200) {
|
||||
throw new Error("Failed to fetch providers");
|
||||
}
|
||||
const provider = providersResponse.data.providers.find(
|
||||
(p) => p.id === providerId,
|
||||
);
|
||||
|
||||
if (!provider) {
|
||||
throw new Error("Provider not found");
|
||||
}
|
||||
|
||||
const payload: CreateLlmModelRequest = {
|
||||
slug: String(formData.get("slug") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: providerId,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: contextWindow,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.getAll("is_enabled").includes("on"),
|
||||
capabilities: {},
|
||||
metadata: {},
|
||||
costs: [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: creditCost,
|
||||
credential_provider:
|
||||
provider.default_credential_provider || provider.name,
|
||||
credential_id: provider.default_credential_id || undefined,
|
||||
credential_type: provider.default_credential_type || "api_key",
|
||||
metadata: {},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const response = await postV2CreateLlmModel(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmModelAction(formData: FormData) {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const creatorId = formData.get("creator_id");
|
||||
|
||||
const payload: UpdateLlmModelRequest = {
|
||||
display_name: formData.get("display_name")
|
||||
? String(formData.get("display_name"))
|
||||
: undefined,
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
provider_id: formData.get("provider_id")
|
||||
? String(formData.get("provider_id"))
|
||||
: undefined,
|
||||
creator_id: creatorId ? String(creatorId) : undefined,
|
||||
context_window: formData.get("context_window")
|
||||
? Number(formData.get("context_window"))
|
||||
: undefined,
|
||||
max_output_tokens: formData.get("max_output_tokens")
|
||||
? Number(formData.get("max_output_tokens"))
|
||||
: undefined,
|
||||
is_enabled: formData.has("is_enabled")
|
||||
? formData.getAll("is_enabled").includes("on")
|
||||
: undefined,
|
||||
costs: formData.get("credit_cost")
|
||||
? [
|
||||
{
|
||||
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
|
||||
credit_cost: Number(formData.get("credit_cost")),
|
||||
credential_provider: String(
|
||||
formData.get("credential_provider") || "",
|
||||
).trim(),
|
||||
credential_id: formData.get("credential_id")
|
||||
? String(formData.get("credential_id"))
|
||||
: undefined,
|
||||
credential_type: formData.get("credential_type")
|
||||
? String(formData.get("credential_type"))
|
||||
: undefined,
|
||||
metadata: {},
|
||||
},
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateLlmModel(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const shouldEnable = formData.get("is_enabled") === "true";
|
||||
const migrateToSlug = formData.get("migrate_to_slug");
|
||||
const migrationReason = formData.get("migration_reason");
|
||||
const customCreditCost = formData.get("custom_credit_cost");
|
||||
|
||||
const payload: ToggleLlmModelRequest = {
|
||||
is_enabled: shouldEnable,
|
||||
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
|
||||
migration_reason: migrationReason ? String(migrationReason) : undefined,
|
||||
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
|
||||
};
|
||||
|
||||
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to toggle LLM model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
const rawReplacement = formData.get("replacement_model_slug");
|
||||
const replacementModelSlug =
|
||||
rawReplacement && String(rawReplacement).trim()
|
||||
? String(rawReplacement).trim()
|
||||
: undefined;
|
||||
|
||||
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
|
||||
replacement_model_slug: replacementModelSlug,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete model");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function fetchLlmModelUsage(
|
||||
modelId: string,
|
||||
): Promise<LlmModelUsageResponse> {
|
||||
const response = await getV2GetModelUsageCount(modelId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch model usage");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Migration Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmMigrations(
|
||||
includeReverted: boolean = false,
|
||||
): Promise<LlmMigrationsResponse> {
|
||||
const response = await getV2ListModelMigrations({
|
||||
include_reverted: includeReverted,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch migrations");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function revertLlmMigrationAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const migrationId = getRequiredFormField(
|
||||
formData,
|
||||
"migration_id",
|
||||
"Migration id",
|
||||
);
|
||||
|
||||
const response = await postV2RevertAModelMigration(migrationId, null);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to revert migration");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
|
||||
const response = await getV2ListModelCreators();
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to fetch creators");
|
||||
}
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await postV2CreateModelCreator(payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function updateLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const payload: UpsertLlmCreatorRequest = {
|
||||
name: String(formData.get("name") || "").trim(),
|
||||
display_name: String(formData.get("display_name") || "").trim(),
|
||||
description: formData.get("description")
|
||||
? String(formData.get("description"))
|
||||
: undefined,
|
||||
website_url: formData.get("website_url")
|
||||
? String(formData.get("website_url")).trim()
|
||||
: undefined,
|
||||
logo_url: formData.get("logo_url")
|
||||
? String(formData.get("logo_url")).trim()
|
||||
: undefined,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const response = await patchV2UpdateModelCreator(creatorId, payload);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to update creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
export async function deleteLlmCreatorAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
|
||||
|
||||
const response = await deleteV2DeleteModelCreator(creatorId);
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to delete creator");
|
||||
}
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Recommended Model Actions
|
||||
// =============================================================================
|
||||
|
||||
export async function setRecommendedModelAction(
|
||||
formData: FormData,
|
||||
): Promise<void> {
|
||||
const modelId = getRequiredFormField(formData, "model_id", "Model id");
|
||||
|
||||
const response = await postV2SetRecommendedModel({ model_id: modelId });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to set recommended model");
|
||||
}
|
||||
|
||||
revalidatePath(ADMIN_LLM_PATH);
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddCreatorModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Creator
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Add a new model creator (the organization that made/trained the
|
||||
model).
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Name (slug) <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Lowercase identifier (e.g., openai, meta, anthropic)
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={2}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Creator of GPT models..."
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="website_url"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Website URL
|
||||
</label>
|
||||
<input
|
||||
id="website_url"
|
||||
name="website_url"
|
||||
type="url"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="https://openai.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Add Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { createLlmModelAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
export function AddModelModal({ providers, creators }: Props) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedCreatorId, setSelectedCreatorId] = useState("");
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to create model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// When provider changes, auto-select matching creator if one exists
|
||||
function handleProviderChange(providerId: string) {
|
||||
const provider = providers.find((p) => p.id === providerId);
|
||||
if (provider) {
|
||||
// Find creator with same name as provider (e.g., "openai" -> "openai")
|
||||
const matchingCreator = creators.find((c) => c.name === provider.name);
|
||||
if (matchingCreator) {
|
||||
setSelectedCreatorId(matchingCreator.id);
|
||||
} else {
|
||||
// No matching creator (e.g., OpenRouter hosts other creators' models)
|
||||
setSelectedCreatorId("");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Model
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Register a new model slug, metadata, and pricing.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core model details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="slug"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Model Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="slug"
|
||||
required
|
||||
name="slug"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="gpt-4.1-mini-2025-04-14"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="GPT 4.1 Mini"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model Configuration */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Model Configuration
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Model capabilities and limits
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="provider_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<select
|
||||
id="provider_id"
|
||||
required
|
||||
name="provider_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
defaultValue=""
|
||||
onChange={(e) => handleProviderChange(e.target.value)}
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((provider) => (
|
||||
<option key={provider.id} value={provider.id}>
|
||||
{provider.display_name} ({provider.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="creator_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Creator
|
||||
</label>
|
||||
<select
|
||||
id="creator_id"
|
||||
name="creator_id"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
value={selectedCreatorId}
|
||||
onChange={(e) => setSelectedCreatorId(e.target.value)}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((creator) => (
|
||||
<option key={creator.id} value={creator.id}>
|
||||
{creator.display_name} ({creator.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="context_window"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Context Window <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="context_window"
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="128000"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="max_output_tokens"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Max Output Tokens
|
||||
</label>
|
||||
<input
|
||||
id="max_output_tokens"
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="16384"
|
||||
min={1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pricing */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost per run (credentials are managed via the provider)
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-1">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="credit_cost"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credit Cost <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="credit_cost"
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
step="1"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="5"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credit cost is always in platform credits. Credentials are
|
||||
inherited from the selected provider.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Enabled Toggle */}
|
||||
<div className="flex items-center gap-3 border-t border-border pt-6">
|
||||
<input type="hidden" name="is_enabled" value="off" />
|
||||
<input
|
||||
id="is_enabled"
|
||||
type="checkbox"
|
||||
name="is_enabled"
|
||||
defaultChecked
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor="is_enabled"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Enabled by default
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { createLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function AddProviderModal() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await createLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to create provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Add Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="primary" size="small">
|
||||
Add Provider
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Define a new upstream provider and default credential information.
|
||||
</div>
|
||||
|
||||
{/* Setup Instructions */}
|
||||
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
|
||||
<div className="space-y-2">
|
||||
<h4 className="text-sm font-semibold text-foreground">
|
||||
Before Adding a Provider
|
||||
</h4>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
To use a new provider, you must first configure its credentials in
|
||||
the backend:
|
||||
</p>
|
||||
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
|
||||
<li>
|
||||
Add the credential to{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/integrations/credentials_store.py
|
||||
</code>{" "}
|
||||
with a UUID, provider name, and settings secret reference
|
||||
</li>
|
||||
<li>
|
||||
Add it to the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
backend/data/block_cost_config.py
|
||||
</code>
|
||||
</li>
|
||||
<li>
|
||||
Use the <strong>same provider name</strong> in the
|
||||
"Credential Provider" field below that matches the key
|
||||
in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</li>
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
required
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<strong>Important:</strong> This must exactly match the key in
|
||||
the{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>{" "}
|
||||
dictionary in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
block_cost_config.py
|
||||
</code>
|
||||
. Common values: "openai", "anthropic",
|
||||
"groq", "open_router", etc.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{ name: "supports_tools", label: "Supports tools" },
|
||||
{ name: "supports_json_output", label: "Supports JSON output" },
|
||||
{ name: "supports_reasoning", label: "Supports reasoning" },
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
},
|
||||
].map(({ name, label }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={
|
||||
name !== "supports_reasoning" &&
|
||||
name !== "supports_parallel_tool"
|
||||
}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Creating..." : "Save Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { updateLlmCreatorAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { DeleteCreatorModal } from "./DeleteCreatorModal";
|
||||
|
||||
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
|
||||
if (!creators.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No creators registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Description</TableHead>
|
||||
<TableHead>Website</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{creators.map((creator) => (
|
||||
<TableRow key={creator.id}>
|
||||
<TableCell>
|
||||
<div className="font-medium">{creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{creator.name}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{creator.description || "—"}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{creator.website_url ? (
|
||||
<a
|
||||
href={creator.website_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-sm text-primary hover:underline"
|
||||
>
|
||||
{(() => {
|
||||
try {
|
||||
return new URL(creator.website_url).hostname;
|
||||
} catch {
|
||||
return creator.website_url;
|
||||
}
|
||||
})()}
|
||||
</a>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<EditCreatorModal creator={creator} />
|
||||
<DeleteCreatorModal creator={creator} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update creator");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "512px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Name (slug)</label>
|
||||
<input
|
||||
required
|
||||
name="name"
|
||||
defaultValue={creator.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Display Name</label>
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={creator.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Description</label>
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={creator.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Website URL</label>
|
||||
<input
|
||||
name="website_url"
|
||||
type="url"
|
||||
defaultValue={creator.website_url ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import { deleteLlmCreatorAction } from "../actions";
|
||||
|
||||
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmCreatorAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete creator");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Creator"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{creator.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({creator.name})
|
||||
</span>
|
||||
</p>
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Models using this creator will have their creator field
|
||||
cleared. This is safe and won't affect model
|
||||
functionality.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="creator_id" value={creator.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Creator"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DeleteModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [usageLoading, setUsageLoading] = useState(false);
|
||||
const [usageError, setUsageError] = useState<string | null>(null);
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const replacementOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
// Check if migration is required (has blocks using this model)
|
||||
const requiresMigration = usageCount !== null && usageCount > 0;
|
||||
|
||||
async function fetchUsage() {
|
||||
setUsageLoading(true);
|
||||
setUsageError(null);
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch (err) {
|
||||
console.error("Failed to fetch model usage:", err);
|
||||
setUsageError("Failed to load usage count");
|
||||
setUsageCount(null);
|
||||
} finally {
|
||||
setUsageLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to delete model");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if delete button should be enabled
|
||||
const canDelete =
|
||||
!isDeleting &&
|
||||
!usageLoading &&
|
||||
usageCount !== null &&
|
||||
(requiresMigration
|
||||
? selectedReplacement && replacementOptions.length > 0
|
||||
: true);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
setUsageError(null);
|
||||
setError(null);
|
||||
setSelectedReplacement("");
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
{requiresMigration
|
||||
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
|
||||
: "This action cannot be undone."}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageLoading && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage count...
|
||||
</p>
|
||||
)}
|
||||
{usageError && (
|
||||
<p className="mt-2 text-destructive">{usageError}</p>
|
||||
)}
|
||||
{!usageLoading && !usageError && usageCount !== null && (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
)}
|
||||
{requiresMigration && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
All workflows currently using this model will be
|
||||
automatically updated to use the replacement model you
|
||||
choose below.
|
||||
</p>
|
||||
)}
|
||||
{!usageLoading && usageCount === 0 && (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are using this model. It can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input
|
||||
type="hidden"
|
||||
name="replacement_model_slug"
|
||||
value={selectedReplacement}
|
||||
/>
|
||||
|
||||
{requiresMigration && (
|
||||
<label className="text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Select Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedReplacement}
|
||||
onChange={(e) => setSelectedReplacement(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{replacementOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{replacementOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No replacement models available. You must have at least one
|
||||
other enabled model before deleting this one.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setSelectedReplacement("");
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={!canDelete}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting
|
||||
? "Deleting..."
|
||||
: requiresMigration
|
||||
? "Delete and Migrate"
|
||||
: "Delete"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { deleteLlmProviderAction } from "../actions";
|
||||
|
||||
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
const modelCount = provider.models?.length ?? 0;
|
||||
const hasModels = modelCount > 0;
|
||||
|
||||
async function handleDelete(formData: FormData) {
|
||||
setIsDeleting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await deleteLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to delete provider",
|
||||
);
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Delete Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0 text-destructive hover:bg-destructive/10"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-4">
|
||||
<div
|
||||
className={`rounded-lg border p-4 ${
|
||||
hasModels
|
||||
? "border-destructive/30 bg-destructive/10"
|
||||
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<div
|
||||
className={`flex-shrink-0 ${
|
||||
hasModels
|
||||
? "text-destructive"
|
||||
: "text-amber-600 dark:text-amber-400"
|
||||
}`}
|
||||
>
|
||||
{hasModels ? "🚫" : "⚠️"}
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to delete:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{provider.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">
|
||||
({provider.name})
|
||||
</span>
|
||||
</p>
|
||||
{hasModels ? (
|
||||
<p className="mt-2 text-destructive">
|
||||
This provider has {modelCount} model(s). You must delete all
|
||||
models before you can delete this provider.
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
This provider has no models and can be safely deleted.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form action={handleDelete} className="space-y-4">
|
||||
<input type="hidden" name="provider_id" value={provider.id} />
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
type="button"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={isDeleting || hasModels}
|
||||
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete Provider"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,288 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
|
||||
|
||||
export function DisableModelModal({
|
||||
model,
|
||||
availableModels,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
availableModels: LlmModel[];
|
||||
}) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isDisabling, setIsDisabling] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [usageCount, setUsageCount] = useState<number | null>(null);
|
||||
const [selectedMigration, setSelectedMigration] = useState<string>("");
|
||||
const [wantsMigration, setWantsMigration] = useState(false);
|
||||
const [migrationReason, setMigrationReason] = useState("");
|
||||
const [customCreditCost, setCustomCreditCost] = useState<string>("");
|
||||
|
||||
// Filter out the current model and disabled models from replacement options
|
||||
const migrationOptions = availableModels.filter(
|
||||
(m) => m.id !== model.id && m.is_enabled,
|
||||
);
|
||||
|
||||
async function fetchUsage() {
|
||||
try {
|
||||
const usage = await fetchLlmModelUsage(model.id);
|
||||
setUsageCount(usage.node_count);
|
||||
} catch {
|
||||
setUsageCount(null);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDisable(formData: FormData) {
|
||||
setIsDisabling(true);
|
||||
setError(null);
|
||||
try {
|
||||
await toggleLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to disable model");
|
||||
} finally {
|
||||
setIsDisabling(false);
|
||||
}
|
||||
}
|
||||
|
||||
function resetState() {
|
||||
setError(null);
|
||||
setSelectedMigration("");
|
||||
setWantsMigration(false);
|
||||
setMigrationReason("");
|
||||
setCustomCreditCost("");
|
||||
}
|
||||
|
||||
const hasUsage = usageCount !== null && usageCount > 0;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Disable Model"
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: async (isOpen) => {
|
||||
setOpen(isOpen);
|
||||
if (isOpen) {
|
||||
setUsageCount(null);
|
||||
resetState();
|
||||
await fetchUsage();
|
||||
}
|
||||
},
|
||||
}}
|
||||
styling={{ maxWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
className="min-w-0"
|
||||
>
|
||||
Disable
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Disabling a model will hide it from users when creating new workflows.
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
|
||||
⚠️
|
||||
</div>
|
||||
<div className="text-sm text-foreground">
|
||||
<p className="font-semibold">You are about to disable:</p>
|
||||
<p className="mt-1">
|
||||
<span className="font-medium">{model.display_name}</span>{" "}
|
||||
<span className="text-muted-foreground">({model.slug})</span>
|
||||
</p>
|
||||
{usageCount === null ? (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Loading usage data...
|
||||
</p>
|
||||
) : usageCount > 0 ? (
|
||||
<p className="mt-2 font-semibold">
|
||||
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
|
||||
currently use this model
|
||||
</p>
|
||||
) : (
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
No workflows are currently using this model.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{hasUsage && (
|
||||
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
|
||||
<label className="flex items-start gap-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={wantsMigration}
|
||||
onChange={(e) => {
|
||||
setWantsMigration(e.target.checked);
|
||||
if (!e.target.checked) {
|
||||
setSelectedMigration("");
|
||||
}
|
||||
}}
|
||||
className="mt-1"
|
||||
/>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">
|
||||
Migrate existing workflows to another model
|
||||
</span>
|
||||
<p className="mt-1 text-muted-foreground">
|
||||
Creates a revertible migration record. If unchecked,
|
||||
existing workflows will use automatic fallback to an enabled
|
||||
model from the same provider.
|
||||
</p>
|
||||
</div>
|
||||
</label>
|
||||
|
||||
{wantsMigration && (
|
||||
<div className="space-y-4 border-t border-border pt-4">
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Replacement Model{" "}
|
||||
<span className="text-destructive">*</span>
|
||||
</span>
|
||||
<select
|
||||
required
|
||||
value={selectedMigration}
|
||||
onChange={(e) => setSelectedMigration(e.target.value)}
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="">-- Choose a replacement model --</option>
|
||||
{migrationOptions.map((m) => (
|
||||
<option key={m.id} value={m.slug}>
|
||||
{m.display_name} ({m.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
{migrationOptions.length === 0 && (
|
||||
<p className="mt-2 text-xs text-destructive">
|
||||
No other enabled models available for migration.
|
||||
</p>
|
||||
)}
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Migration Reason{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="text"
|
||||
value={migrationReason}
|
||||
onChange={(e) => setMigrationReason(e.target.value)}
|
||||
placeholder="e.g., Provider outage, Cost reduction"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Helps track why the migration was made
|
||||
</p>
|
||||
</label>
|
||||
|
||||
<label className="block text-sm font-medium">
|
||||
<span className="mb-2 block">
|
||||
Custom Credit Cost{" "}
|
||||
<span className="font-normal text-muted-foreground">
|
||||
(optional)
|
||||
</span>
|
||||
</span>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={customCreditCost}
|
||||
onChange={(e) => setCustomCreditCost(e.target.value)}
|
||||
placeholder="Leave blank to use target model's cost"
|
||||
className="w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
Override pricing for migrated workflows. When set, billing
|
||||
will use this cost instead of the target model's
|
||||
cost.
|
||||
</p>
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form action={handleDisable} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
<input type="hidden" name="is_enabled" value="false" />
|
||||
{wantsMigration && selectedMigration && (
|
||||
<>
|
||||
<input
|
||||
type="hidden"
|
||||
name="migrate_to_slug"
|
||||
value={selectedMigration}
|
||||
/>
|
||||
{migrationReason && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="migration_reason"
|
||||
value={migrationReason}
|
||||
/>
|
||||
)}
|
||||
{customCreditCost && (
|
||||
<input
|
||||
type="hidden"
|
||||
name="custom_credit_cost"
|
||||
value={customCreditCost}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
resetState();
|
||||
}}
|
||||
disabled={isDisabling}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
size="small"
|
||||
disabled={
|
||||
isDisabling ||
|
||||
(wantsMigration && !selectedMigration) ||
|
||||
usageCount === null
|
||||
}
|
||||
>
|
||||
{isDisabling
|
||||
? "Disabling..."
|
||||
: wantsMigration && selectedMigration
|
||||
? "Disable & Migrate"
|
||||
: "Disable Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { updateLlmModelAction } from "../actions";
|
||||
|
||||
export function EditModelModal({
|
||||
model,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
model: LlmModel;
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providers.find((p) => p.id === model.provider_id);
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmModelAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to update model");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Model"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small" className="min-w-0">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Update model metadata and pricing information.
|
||||
</div>
|
||||
{error && (
|
||||
<div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
<form action={handleSubmit} className="space-y-4">
|
||||
<input type="hidden" name="model_id" value={model.id} />
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Display Name
|
||||
<input
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={model.display_name}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
/>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Provider
|
||||
<select
|
||||
required
|
||||
name="provider_id"
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
defaultValue={model.provider_id}
|
||||
>
|
||||
{providers.map((p) => (
|
||||
<option key={p.id} value={p.id}>
|
||||
{p.display_name} ({p.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Who hosts/serves the model
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Creator
|
||||
<select
|
||||
name="creator_id"
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
defaultValue={model.creator_id ?? ""}
|
||||
>
|
||||
<option value="">No creator selected</option>
|
||||
{creators.map((c) => (
|
||||
<option key={c.id} value={c.id}>
|
||||
{c.display_name} ({c.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Who made/trained the model (e.g., OpenAI, Meta)
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label className="text-sm font-medium">
|
||||
Description
|
||||
<textarea
|
||||
name="description"
|
||||
rows={2}
|
||||
defaultValue={model.description ?? ""}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</label>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Context Window
|
||||
<input
|
||||
required
|
||||
type="number"
|
||||
name="context_window"
|
||||
defaultValue={model.context_window}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={1}
|
||||
/>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Max Output Tokens
|
||||
<input
|
||||
type="number"
|
||||
name="max_output_tokens"
|
||||
defaultValue={model.max_output_tokens ?? undefined}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={1}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<label className="text-sm font-medium">
|
||||
Credit Cost
|
||||
<input
|
||||
required
|
||||
type="number"
|
||||
name="credit_cost"
|
||||
defaultValue={cost?.credit_cost ?? 0}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
min={0}
|
||||
/>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Credits charged per run
|
||||
</span>
|
||||
</label>
|
||||
<label className="text-sm font-medium">
|
||||
Credential Provider
|
||||
<select
|
||||
required
|
||||
name="credential_provider"
|
||||
defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
|
||||
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
|
||||
>
|
||||
<option value="" disabled>
|
||||
Select provider
|
||||
</option>
|
||||
{providers.map((p) => (
|
||||
<option key={p.id} value={p.name}>
|
||||
{p.display_name} ({p.name})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Must match a key in PROVIDER_CREDENTIALS
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
{/* Hidden defaults for credential_type and unit */}
|
||||
<input
|
||||
type="hidden"
|
||||
name="credential_type"
|
||||
value={
|
||||
cost?.credential_type ??
|
||||
provider?.default_credential_type ??
|
||||
"api_key"
|
||||
}
|
||||
/>
|
||||
<input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => setOpen(false)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Updating..." : "Update Model"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { updateLlmProviderAction } from "../actions";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
|
||||
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
async function handleSubmit(formData: FormData) {
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await updateLlmProviderAction(formData);
|
||||
setOpen(false);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to update provider",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Provider"
|
||||
controlled={{ isOpen: open, set: setOpen }}
|
||||
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<Button variant="outline" size="small">
|
||||
Edit
|
||||
</Button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="mb-4 text-sm text-muted-foreground">
|
||||
Update provider configuration and capabilities.
|
||||
</div>
|
||||
|
||||
<form action={handleSubmit} className="space-y-6">
|
||||
<input type="hidden" name="provider_id" value={provider.id} />
|
||||
|
||||
{/* Basic Information */}
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Basic Information
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Core provider details
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Provider Slug <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="name"
|
||||
required
|
||||
name="name"
|
||||
defaultValue={provider.name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="e.g. openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="display_name"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Display Name <span className="text-destructive">*</span>
|
||||
</label>
|
||||
<input
|
||||
id="display_name"
|
||||
required
|
||||
name="display_name"
|
||||
defaultValue={provider.display_name}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="OpenAI"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="description"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Description
|
||||
</label>
|
||||
<textarea
|
||||
id="description"
|
||||
name="description"
|
||||
rows={3}
|
||||
defaultValue={provider.description ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional description..."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Default Credentials */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Default Credentials
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Credential provider name that matches the key in{" "}
|
||||
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
|
||||
PROVIDER_CREDENTIALS
|
||||
</code>
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_provider"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Provider
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_provider"
|
||||
name="default_credential_provider"
|
||||
defaultValue={provider.default_credential_provider ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="openai"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_id"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential ID
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_id"
|
||||
name="default_credential_id"
|
||||
defaultValue={provider.default_credential_id ?? ""}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="Optional credential ID"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label
|
||||
htmlFor="default_credential_type"
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
Credential Type
|
||||
</label>
|
||||
<input
|
||||
id="default_credential_type"
|
||||
name="default_credential_type"
|
||||
defaultValue={provider.default_credential_type ?? "api_key"}
|
||||
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
|
||||
placeholder="api_key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Capabilities */}
|
||||
<div className="space-y-4 border-t border-border pt-6">
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Capabilities
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provider feature flags
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
{[
|
||||
{
|
||||
name: "supports_tools",
|
||||
label: "Supports tools",
|
||||
checked: provider.supports_tools,
|
||||
},
|
||||
{
|
||||
name: "supports_json_output",
|
||||
label: "Supports JSON output",
|
||||
checked: provider.supports_json_output,
|
||||
},
|
||||
{
|
||||
name: "supports_reasoning",
|
||||
label: "Supports reasoning",
|
||||
checked: provider.supports_reasoning,
|
||||
},
|
||||
{
|
||||
name: "supports_parallel_tool",
|
||||
label: "Supports parallel tool calls",
|
||||
checked: provider.supports_parallel_tool,
|
||||
},
|
||||
].map(({ name, label, checked }) => (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
|
||||
>
|
||||
<input type="hidden" name={name} value="off" />
|
||||
<input
|
||||
id={name}
|
||||
type="checkbox"
|
||||
name={name}
|
||||
defaultChecked={checked}
|
||||
className="h-4 w-4 rounded border-input"
|
||||
/>
|
||||
<label
|
||||
htmlFor={name}
|
||||
className="text-sm font-medium text-foreground"
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setError(null);
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</form>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { AddProviderModal } from "./AddProviderModal";
|
||||
import { AddModelModal } from "./AddModelModal";
|
||||
import { AddCreatorModal } from "./AddCreatorModal";
|
||||
import { ProviderList } from "./ProviderList";
|
||||
import { ModelsTable } from "./ModelsTable";
|
||||
import { MigrationsTable } from "./MigrationsTable";
|
||||
import { CreatorsTable } from "./CreatorsTable";
|
||||
import { RecommendedModelSelector } from "./RecommendedModelSelector";
|
||||
|
||||
interface Props {
|
||||
providers: LlmProvider[];
|
||||
models: LlmModel[];
|
||||
migrations: LlmModelMigration[];
|
||||
creators: LlmModelCreator[];
|
||||
}
|
||||
|
||||
function AdminErrorFallback() {
|
||||
return (
|
||||
<div className="mx-auto max-w-xl p-6">
|
||||
<ErrorCard
|
||||
responseError={{
|
||||
message:
|
||||
"An error occurred while loading the LLM Registry. Please refresh the page.",
|
||||
}}
|
||||
context="llm-registry"
|
||||
onRetry={() => window.location.reload()}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function LlmRegistryDashboard({
|
||||
providers,
|
||||
models,
|
||||
migrations,
|
||||
creators,
|
||||
}: Props) {
|
||||
return (
|
||||
<ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Header */}
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">LLM Registry</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Manage providers, creators, models, and credit pricing
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Active Migrations Section - Only show if there are migrations */}
|
||||
{migrations.length > 0 && (
|
||||
<div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
|
||||
<div className="mb-4">
|
||||
<h2 className="text-xl font-semibold">Active Migrations</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
These migrations can be reverted to restore workflows to their
|
||||
original model
|
||||
</p>
|
||||
</div>
|
||||
<MigrationsTable migrations={migrations} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Providers & Creators Section - Side by Side */}
|
||||
<div className="grid gap-6 lg:grid-cols-2">
|
||||
{/* Providers */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Providers</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Who hosts/serves the models
|
||||
</p>
|
||||
</div>
|
||||
<AddProviderModal />
|
||||
</div>
|
||||
<ProviderList providers={providers} />
|
||||
</div>
|
||||
|
||||
{/* Creators */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Creators</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Who made/trained the models
|
||||
</p>
|
||||
</div>
|
||||
<AddCreatorModal />
|
||||
</div>
|
||||
<CreatorsTable creators={creators} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Models Section */}
|
||||
<div className="rounded-lg border bg-card p-6 shadow-sm">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold">Models</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Toggle availability, adjust context windows, and update credit
|
||||
pricing
|
||||
</p>
|
||||
</div>
|
||||
<AddModelModal providers={providers} creators={creators} />
|
||||
</div>
|
||||
|
||||
{/* Recommended Model Selector */}
|
||||
<div className="mb-6">
|
||||
<RecommendedModelSelector models={models} />
|
||||
</div>
|
||||
|
||||
<ModelsTable
|
||||
models={models}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { revertLlmMigrationAction } from "../actions";
|
||||
|
||||
export function MigrationsTable({
|
||||
migrations,
|
||||
}: {
|
||||
migrations: LlmModelMigration[];
|
||||
}) {
|
||||
if (!migrations.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No active migrations. Migrations are created when you disable a model
|
||||
with the "Migrate existing workflows" option.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Migration</TableHead>
|
||||
<TableHead>Reason</TableHead>
|
||||
<TableHead>Nodes Affected</TableHead>
|
||||
<TableHead>Custom Cost</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{migrations.map((migration) => (
|
||||
<MigrationRow key={migration.id} migration={migration} />
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
|
||||
const [isReverting, setIsReverting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
async function handleRevert(formData: FormData) {
|
||||
setIsReverting(true);
|
||||
setError(null);
|
||||
try {
|
||||
await revertLlmMigrationAction(formData);
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to revert migration",
|
||||
);
|
||||
} finally {
|
||||
setIsReverting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const createdDate = new Date(migration.created_at);
|
||||
|
||||
return (
|
||||
<>
|
||||
<TableRow>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
<span className="font-medium">{migration.source_model_slug}</span>
|
||||
<span className="mx-2 text-muted-foreground">→</span>
|
||||
<span className="font-medium">{migration.target_model_slug}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{migration.reason || "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">{migration.node_count}</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm">
|
||||
{migration.custom_credit_cost !== null &&
|
||||
migration.custom_credit_cost !== undefined
|
||||
? `${migration.custom_credit_cost} credits`
|
||||
: "—"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{createdDate.toLocaleDateString()}{" "}
|
||||
{createdDate.toLocaleTimeString([], {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<form action={handleRevert} className="inline">
|
||||
<input type="hidden" name="migration_id" value={migration.id} />
|
||||
<Button
|
||||
type="submit"
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={isReverting}
|
||||
>
|
||||
{isReverting ? "Reverting..." : "Revert"}
|
||||
</Button>
|
||||
</form>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
{error && (
|
||||
<TableRow>
|
||||
<TableCell colSpan={6}>
|
||||
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { toggleLlmModelAction } from "../actions";
|
||||
import { DeleteModelModal } from "./DeleteModelModal";
|
||||
import { DisableModelModal } from "./DisableModelModal";
|
||||
import { EditModelModal } from "./EditModelModal";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
|
||||
export function ModelsTable({
|
||||
models,
|
||||
providers,
|
||||
creators,
|
||||
}: {
|
||||
models: LlmModel[];
|
||||
providers: LlmProvider[];
|
||||
creators: LlmModelCreator[];
|
||||
}) {
|
||||
if (!models.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No models registered yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const providerLookup = new Map(
|
||||
providers.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead>Provider</TableHead>
|
||||
<TableHead>Creator</TableHead>
|
||||
<TableHead>Context Window</TableHead>
|
||||
<TableHead>Max Output</TableHead>
|
||||
<TableHead>Cost</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{models.map((model) => {
|
||||
const cost = model.costs?.[0];
|
||||
const provider = providerLookup.get(model.provider_id);
|
||||
return (
|
||||
<TableRow
|
||||
key={model.id}
|
||||
className={model.is_enabled ? "" : "opacity-60"}
|
||||
>
|
||||
<TableCell>
|
||||
<div className="font-medium">{model.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.slug}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{provider ? (
|
||||
<>
|
||||
<div>{provider.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{provider.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
model.provider_id
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{model.creator ? (
|
||||
<>
|
||||
<div>{model.creator.display_name}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.creator.name}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-muted-foreground">—</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>{model.context_window.toLocaleString()}</TableCell>
|
||||
<TableCell>
|
||||
{model.max_output_tokens
|
||||
? model.max_output_tokens.toLocaleString()
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{cost ? (
|
||||
<>
|
||||
<div className="font-medium">
|
||||
{cost.credit_cost} credits
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{cost.credential_provider}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
"—"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span
|
||||
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
|
||||
model.is_enabled
|
||||
? "bg-primary/10 text-primary"
|
||||
: "bg-muted text-muted-foreground"
|
||||
}`}
|
||||
>
|
||||
{model.is_enabled ? "Enabled" : "Disabled"}
|
||||
</span>
|
||||
{model.is_recommended && (
|
||||
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
|
||||
<Star size={12} weight="fill" />
|
||||
Recommended
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{model.is_enabled ? (
|
||||
<DisableModelModal
|
||||
model={model}
|
||||
availableModels={models}
|
||||
/>
|
||||
) : (
|
||||
<EnableModelButton modelId={model.id} />
|
||||
)}
|
||||
<EditModelModal
|
||||
model={model}
|
||||
providers={providers}
|
||||
creators={creators}
|
||||
/>
|
||||
<DeleteModelModal model={model} availableModels={models} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EnableModelButton({ modelId }: { modelId: string }) {
|
||||
return (
|
||||
<form action={toggleLlmModelAction} className="inline">
|
||||
<input type="hidden" name="model_id" value={modelId} />
|
||||
<input type="hidden" name="is_enabled" value="true" />
|
||||
<Button type="submit" variant="outline" size="small" className="min-w-0">
|
||||
Enable
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/atoms/Table/Table";
|
||||
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
|
||||
import { DeleteProviderModal } from "./DeleteProviderModal";
|
||||
import { EditProviderModal } from "./EditProviderModal";
|
||||
|
||||
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
|
||||
if (!providers.length) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
|
||||
No providers configured yet.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Display Name</TableHead>
|
||||
<TableHead>Default Credential</TableHead>
|
||||
<TableHead>Capabilities</TableHead>
|
||||
<TableHead>Models</TableHead>
|
||||
<TableHead className="w-[100px]">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{providers.map((provider) => (
|
||||
<TableRow key={provider.id}>
|
||||
<TableCell className="font-medium">{provider.name}</TableCell>
|
||||
<TableCell>{provider.display_name}</TableCell>
|
||||
<TableCell>
|
||||
{provider.default_credential_provider
|
||||
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
|
||||
: "—"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-muted-foreground">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{provider.supports_tools && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Tools
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_json_output && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
JSON
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_reasoning && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Reasoning
|
||||
</span>
|
||||
)}
|
||||
{provider.supports_parallel_tool && (
|
||||
<span className="rounded bg-muted px-2 py-0.5 text-xs">
|
||||
Parallel Tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
<span
|
||||
className={
|
||||
(provider.models?.length ?? 0) > 0
|
||||
? "text-foreground"
|
||||
: "text-muted-foreground"
|
||||
}
|
||||
>
|
||||
{provider.models?.length ?? 0}
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-2">
|
||||
<EditProviderModal provider={provider} />
|
||||
<DeleteProviderModal provider={provider} />
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { setRecommendedModelAction } from "../actions";
|
||||
import { Star } from "@phosphor-icons/react";
|
||||
|
||||
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
|
||||
const router = useRouter();
|
||||
const enabledModels = models.filter((m) => m.is_enabled);
|
||||
const currentRecommended = models.find((m) => m.is_recommended);
|
||||
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>(
|
||||
currentRecommended?.id || "",
|
||||
);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const hasChanges = selectedModelId !== (currentRecommended?.id || "");
|
||||
|
||||
async function handleSave() {
|
||||
if (!selectedModelId) return;
|
||||
|
||||
setIsSaving(true);
|
||||
setError(null);
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.set("model_id", selectedModelId);
|
||||
await setRecommendedModelAction(formData);
|
||||
router.refresh();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to save");
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-border bg-card p-4">
|
||||
<div className="mb-3 flex items-center gap-2">
|
||||
<Star size={20} weight="fill" className="text-amber-500" />
|
||||
<h3 className="text-sm font-semibold">Recommended Model</h3>
|
||||
</div>
|
||||
<p className="mb-3 text-xs text-muted-foreground">
|
||||
The recommended model is shown as the default suggestion in model
|
||||
selection dropdowns throughout the platform.
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
<select
|
||||
value={selectedModelId}
|
||||
onChange={(e) => setSelectedModelId(e.target.value)}
|
||||
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
disabled={isSaving}
|
||||
>
|
||||
<option value="">-- Select a model --</option>
|
||||
{enabledModels.map((model) => (
|
||||
<option key={model.id} value={model.id}>
|
||||
{model.display_name} ({model.slug})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleSave}
|
||||
disabled={!hasChanges || !selectedModelId || isSaving}
|
||||
>
|
||||
{isSaving ? "Saving..." : "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
|
||||
|
||||
{currentRecommended && !hasChanges && (
|
||||
<p className="mt-2 text-xs text-muted-foreground">
|
||||
Currently set to:{" "}
|
||||
<span className="font-medium">{currentRecommended.display_name}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/**
|
||||
* Server-side data fetching for LLM Registry page.
|
||||
*/
|
||||
|
||||
import {
|
||||
fetchLlmCreators,
|
||||
fetchLlmMigrations,
|
||||
fetchLlmModels,
|
||||
fetchLlmProviders,
|
||||
} from "./actions";
|
||||
|
||||
export async function getLlmRegistryPageData() {
|
||||
// Fetch providers and models (required)
|
||||
const [providersResponse, modelsResponse] = await Promise.all([
|
||||
fetchLlmProviders(),
|
||||
fetchLlmModels(),
|
||||
]);
|
||||
|
||||
// Fetch migrations separately with fallback (table might not exist yet)
|
||||
let migrations: Awaited<ReturnType<typeof fetchLlmMigrations>>["migrations"] =
|
||||
[];
|
||||
try {
|
||||
const migrationsResponse = await fetchLlmMigrations(false);
|
||||
migrations = migrationsResponse.migrations;
|
||||
} catch {
|
||||
// Migrations table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch migrations - table may not exist yet");
|
||||
}
|
||||
|
||||
// Fetch creators separately with fallback (table might not exist yet)
|
||||
let creators: Awaited<ReturnType<typeof fetchLlmCreators>>["creators"] = [];
|
||||
try {
|
||||
const creatorsResponse = await fetchLlmCreators();
|
||||
creators = creatorsResponse.creators;
|
||||
} catch {
|
||||
// Creators table might not exist yet - that's ok, just show empty list
|
||||
console.warn("Could not fetch creators - table may not exist yet");
|
||||
}
|
||||
|
||||
return {
|
||||
providers: providersResponse.providers,
|
||||
models: modelsResponse.models,
|
||||
migrations,
|
||||
creators,
|
||||
};
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
|
||||
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
|
||||
|
||||
async function LlmRegistryPage() {
|
||||
const data = await getLlmRegistryPageData();
|
||||
return <LlmRegistryDashboard {...data} />;
|
||||
}
|
||||
|
||||
export default async function AdminLlmRegistryPage() {
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedLlmRegistryPage = await withAdminAccess(LlmRegistryPage);
|
||||
return <ProtectedLlmRegistryPage />;
|
||||
}
|
||||
@@ -9,7 +9,7 @@ export async function GET(request: Request) {
|
||||
const { searchParams, origin } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
|
||||
let next = "/";
|
||||
let next = "/marketplace";
|
||||
|
||||
if (code) {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
@@ -7,9 +7,8 @@ import { BlockCategoryResponse } from "@/app/api/__generated__/models/blockCateg
|
||||
import { BlockResponse } from "@/app/api/__generated__/models/blockResponse";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
|
||||
export const useAllBlockContent = () => {
|
||||
const { toast } = useToast();
|
||||
@@ -94,32 +93,6 @@ export const useAllBlockContent = () => {
|
||||
const isErrorOnLoadingMore = (categoryName: string) =>
|
||||
errorLoadingCategories.has(categoryName);
|
||||
|
||||
// Listen for LLM registry refresh notifications
|
||||
useEffect(() => {
|
||||
const api = new BackendApi();
|
||||
const queryClient = getQueryClient();
|
||||
|
||||
const handleNotification = (notification: any) => {
|
||||
if (
|
||||
notification?.type === "LLM_REGISTRY_REFRESH" ||
|
||||
notification?.event === "registry_updated"
|
||||
) {
|
||||
// Invalidate all block-related queries to force refresh
|
||||
const categoriesQueryKey = getGetV2GetBuilderBlockCategoriesQueryKey();
|
||||
queryClient.invalidateQueries({ queryKey: categoriesQueryKey });
|
||||
}
|
||||
};
|
||||
|
||||
const unsubscribe = api.onWebSocketMessage(
|
||||
"notification",
|
||||
handleNotification,
|
||||
);
|
||||
|
||||
return () => {
|
||||
unsubscribe();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
data,
|
||||
isLoading,
|
||||
|
||||
@@ -610,11 +610,8 @@ const NodeOneOfDiscriminatorField: FC<{
|
||||
|
||||
return oneOfVariants
|
||||
.map((variant) => {
|
||||
const discProperty = variant.properties?.[discriminatorProperty];
|
||||
const variantDiscValue =
|
||||
discProperty && "const" in discProperty
|
||||
? (discProperty.const as string)
|
||||
: undefined; // NOTE: can discriminators only be strings?
|
||||
const variantDiscValue = variant.properties?.[discriminatorProperty]
|
||||
?.const as string; // NOTE: can discriminators only be strings?
|
||||
|
||||
return {
|
||||
value: variantDiscValue,
|
||||
@@ -1127,47 +1124,9 @@ const NodeStringInput: FC<{
|
||||
displayName,
|
||||
}) => {
|
||||
value ||= schema.default || "";
|
||||
|
||||
// Check if we have options with labels (e.g., LLM model picker)
|
||||
const hasOptions = schema.options && schema.options.length > 0;
|
||||
const hasEnum = schema.enum && schema.enum.length > 0;
|
||||
|
||||
// Helper to get display label for a value
|
||||
const getDisplayLabel = (val: string) => {
|
||||
if (hasOptions) {
|
||||
const option = schema.options!.find((opt) => opt.value === val);
|
||||
return option?.label || beautifyString(val);
|
||||
}
|
||||
return beautifyString(val);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
{hasOptions ? (
|
||||
// Render options with proper labels (used by LLM model picker)
|
||||
<Select
|
||||
defaultValue={value}
|
||||
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder || displayName}>
|
||||
{value ? getDisplayLabel(value) : undefined}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{schema.options!.map((option, index) => (
|
||||
<SelectItem
|
||||
key={index}
|
||||
value={option.value}
|
||||
title={option.description}
|
||||
>
|
||||
{option.label || beautifyString(option.value)}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
) : hasEnum ? (
|
||||
// Fallback to enum with beautified strings
|
||||
{schema.enum && schema.enum.length > 0 ? (
|
||||
<Select
|
||||
defaultValue={value}
|
||||
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
|
||||
@@ -1176,8 +1135,8 @@ const NodeStringInput: FC<{
|
||||
<SelectValue placeholder={schema.placeholder || displayName} />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{schema
|
||||
.enum!.filter((option) => option)
|
||||
{schema.enum
|
||||
.filter((option) => option)
|
||||
.map((option, index) => (
|
||||
<SelectItem key={index} value={option}>
|
||||
{beautifyString(option)}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { List } from "@phosphor-icons/react";
|
||||
import React, { useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
|
||||
import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer";
|
||||
import { useChat } from "./useChat";
|
||||
|
||||
export interface ChatProps {
|
||||
className?: string;
|
||||
headerTitle?: React.ReactNode;
|
||||
showHeader?: boolean;
|
||||
showSessionInfo?: boolean;
|
||||
showNewChatButton?: boolean;
|
||||
onNewChat?: () => void;
|
||||
headerActions?: React.ReactNode;
|
||||
}
|
||||
|
||||
export function Chat({
|
||||
className,
|
||||
headerTitle = "AutoGPT Copilot",
|
||||
showHeader = true,
|
||||
showSessionInfo = true,
|
||||
showNewChatButton = true,
|
||||
onNewChat,
|
||||
headerActions,
|
||||
}: ChatProps) {
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
sessionId,
|
||||
createSession,
|
||||
clearSession,
|
||||
loadSession,
|
||||
} = useChat();
|
||||
|
||||
const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false);
|
||||
|
||||
const handleNewChat = () => {
|
||||
clearSession();
|
||||
onNewChat?.();
|
||||
};
|
||||
|
||||
const handleSelectSession = async (sessionId: string) => {
|
||||
try {
|
||||
await loadSession(sessionId);
|
||||
} catch (err) {
|
||||
console.error("Failed to load session:", err);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
{/* Header */}
|
||||
{showHeader && (
|
||||
<header className="shrink-0 border-t border-zinc-200 bg-white p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
aria-label="View sessions"
|
||||
onClick={() => setIsSessionsDrawerOpen(true)}
|
||||
className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
|
||||
>
|
||||
<List width="1.25rem" height="1.25rem" />
|
||||
</button>
|
||||
{typeof headerTitle === "string" ? (
|
||||
<Text variant="h2" className="text-lg font-semibold">
|
||||
{headerTitle}
|
||||
</Text>
|
||||
) : (
|
||||
headerTitle
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
{showSessionInfo && sessionId && (
|
||||
<>
|
||||
{showNewChatButton && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{headerActions}
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
)}
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex min-h-0 flex-1 flex-col overflow-hidden">
|
||||
{/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
|
||||
{(isLoading || isCreating || (!sessionId && !error)) && (
|
||||
<ChatLoadingState
|
||||
message={isCreating ? "Creating session..." : "Loading..."}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
{/* Sessions Drawer */}
|
||||
<SessionsDrawer
|
||||
isOpen={isSessionsDrawerOpen}
|
||||
onClose={() => setIsSessionsDrawerOpen(false)}
|
||||
onSelectSession={handleSelectSession}
|
||||
currentSessionId={sessionId}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -21,7 +21,7 @@ export function AuthPromptWidget({
|
||||
message,
|
||||
sessionId,
|
||||
agentInfo,
|
||||
returnUrl = "/copilot/chat",
|
||||
returnUrl = "/chat",
|
||||
className,
|
||||
}: AuthPromptWidgetProps) {
|
||||
const router = useRouter();
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCallback } from "react";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import { ChatInput } from "../ChatInput/ChatInput";
|
||||
import { MessageList } from "../MessageList/MessageList";
|
||||
import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome";
|
||||
import { useChatContainer } from "./useChatContainer";
|
||||
|
||||
export interface ChatContainerProps {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
className,
|
||||
}: ChatContainerProps) {
|
||||
const { messages, streamingChunks, isStreaming, sendMessage } =
|
||||
useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
});
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Wrap sendMessage to automatically capture page context
|
||||
const sendMessageWithContext = useCallback(
|
||||
async (content: string, isUserMessage: boolean = true) => {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
},
|
||||
[sendMessage, capturePageContext],
|
||||
);
|
||||
|
||||
const quickActions = [
|
||||
"Find agents for social media management",
|
||||
"Show me agents for content creation",
|
||||
"Help me automate my business",
|
||||
"What can you help me with?",
|
||||
];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn("flex h-full min-h-0 flex-col", className)}
|
||||
style={{
|
||||
backgroundColor: "#ffffff",
|
||||
backgroundImage:
|
||||
"radial-gradient(#e5e5e5 0.5px, transparent 0.5px), radial-gradient(#e5e5e5 0.5px, #ffffff 0.5px)",
|
||||
backgroundSize: "20px 20px",
|
||||
backgroundPosition: "0 0, 10px 10px",
|
||||
}}
|
||||
>
|
||||
{/* Messages or Welcome Screen */}
|
||||
<div className="flex min-h-0 flex-1 flex-col overflow-hidden pb-24">
|
||||
{messages.length === 0 ? (
|
||||
<QuickActionsWelcome
|
||||
title="Welcome to AutoGPT Copilot"
|
||||
description="Start a conversation to discover and run AI agents."
|
||||
actions={quickActions}
|
||||
onActionClick={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
/>
|
||||
) : (
|
||||
<MessageList
|
||||
messages={messages}
|
||||
streamingChunks={streamingChunks}
|
||||
isStreaming={isStreaming}
|
||||
onSendMessage={sendMessageWithContext}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Input - Always visible */}
|
||||
<div className="fixed bottom-0 left-0 right-0 z-50 border-t border-zinc-200 bg-white p-4">
|
||||
<ChatInput
|
||||
onSend={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
placeholder={
|
||||
sessionId ? "Type your message..." : "Creating session..."
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { toast } from "sonner";
|
||||
import { StreamChunk } from "../../useChatStream";
|
||||
import type { HandlerDependencies } from "./handlers";
|
||||
import type { HandlerDependencies } from "./useChatContainer.handlers";
|
||||
import {
|
||||
handleError,
|
||||
handleLoginNeeded,
|
||||
@@ -9,30 +9,12 @@ import {
|
||||
handleTextEnded,
|
||||
handleToolCallStart,
|
||||
handleToolResponse,
|
||||
isRegionBlockedError,
|
||||
} from "./handlers";
|
||||
} from "./useChatContainer.handlers";
|
||||
|
||||
export function createStreamEventDispatcher(
|
||||
deps: HandlerDependencies,
|
||||
): (chunk: StreamChunk) => void {
|
||||
return function dispatchStreamEvent(chunk: StreamChunk): void {
|
||||
if (
|
||||
chunk.type === "text_chunk" ||
|
||||
chunk.type === "tool_call_start" ||
|
||||
chunk.type === "tool_response" ||
|
||||
chunk.type === "login_needed" ||
|
||||
chunk.type === "need_login" ||
|
||||
chunk.type === "error"
|
||||
) {
|
||||
if (!deps.hasResponseRef.current) {
|
||||
console.info("[ChatStream] First response chunk:", {
|
||||
type: chunk.type,
|
||||
sessionId: deps.sessionId,
|
||||
});
|
||||
}
|
||||
deps.hasResponseRef.current = true;
|
||||
}
|
||||
|
||||
switch (chunk.type) {
|
||||
case "text_chunk":
|
||||
handleTextChunk(chunk, deps);
|
||||
@@ -56,23 +38,15 @@ export function createStreamEventDispatcher(
|
||||
break;
|
||||
|
||||
case "stream_end":
|
||||
console.info("[ChatStream] Stream ended:", {
|
||||
sessionId: deps.sessionId,
|
||||
hasResponse: deps.hasResponseRef.current,
|
||||
chunkCount: deps.streamingChunksRef.current.length,
|
||||
});
|
||||
handleStreamEnd(chunk, deps);
|
||||
break;
|
||||
|
||||
case "error":
|
||||
const isRegionBlocked = isRegionBlockedError(chunk);
|
||||
handleError(chunk, deps);
|
||||
// Show toast at dispatcher level to avoid circular dependencies
|
||||
if (!isRegionBlocked) {
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
}
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
break;
|
||||
|
||||
case "usage":
|
||||
@@ -1,33 +1,6 @@
|
||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
|
||||
export function hasSentInitialPrompt(sessionId: string): boolean {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
return sent[sessionId] === true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function markInitialPromptSent(sessionId: string): void {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
sent[sessionId] = true;
|
||||
sessionStorage.set(
|
||||
SessionKey.CHAT_SENT_INITIAL_PROMPTS,
|
||||
JSON.stringify(sent),
|
||||
);
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
export function removePageContext(content: string): string {
|
||||
// Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
|
||||
let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");
|
||||
@@ -234,22 +207,12 @@ export function parseToolResponse(
|
||||
if (responseType === "setup_requirements") {
|
||||
return null;
|
||||
}
|
||||
if (responseType === "understanding_updated") {
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: (parsedResult || result) as ToolResult,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: parsedResult ? (parsedResult as ToolResult) : result,
|
||||
result,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
@@ -7,30 +7,15 @@ import {
|
||||
parseToolResponse,
|
||||
} from "./helpers";
|
||||
|
||||
function isToolCallMessage(
|
||||
message: ChatMessageData,
|
||||
): message is Extract<ChatMessageData, { type: "tool_call" }> {
|
||||
return message.type === "tool_call";
|
||||
}
|
||||
|
||||
export interface HandlerDependencies {
|
||||
setHasTextChunks: Dispatch<SetStateAction<boolean>>;
|
||||
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
||||
streamingChunksRef: MutableRefObject<string[]>;
|
||||
hasResponseRef: MutableRefObject<boolean>;
|
||||
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
if (chunk.code === "MODEL_NOT_AVAILABLE_REGION") return true;
|
||||
const message = chunk.message || chunk.content;
|
||||
if (typeof message !== "string") return false;
|
||||
return message.toLowerCase().includes("not available in your region");
|
||||
}
|
||||
|
||||
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
if (!chunk.content) return;
|
||||
deps.setHasTextChunks(true);
|
||||
@@ -45,17 +30,16 @@ export function handleTextEnded(
|
||||
_chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Text Ended] Saving streamed text as assistant message");
|
||||
const completedText = deps.streamingChunksRef.current.join("");
|
||||
if (completedText.trim()) {
|
||||
deps.setMessages((prev) => {
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
return [...prev, assistantMessage];
|
||||
});
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
@@ -66,45 +50,30 @@ export function handleToolCallStart(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
||||
const toolCallMessage: ChatMessageData = {
|
||||
type: "tool_call",
|
||||
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||
toolName: chunk.tool_name || "Executing",
|
||||
toolName: chunk.tool_name || "Executing...",
|
||||
arguments: chunk.arguments || {},
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
function updateToolCallMessages(prev: ChatMessageData[]) {
|
||||
const existingIndex = prev.findIndex(function findToolCallIndex(msg) {
|
||||
return isToolCallMessage(msg) && msg.toolId === toolCallMessage.toolId;
|
||||
});
|
||||
if (existingIndex === -1) {
|
||||
return [...prev, toolCallMessage];
|
||||
}
|
||||
const nextMessages = [...prev];
|
||||
const existing = nextMessages[existingIndex];
|
||||
if (!isToolCallMessage(existing)) return prev;
|
||||
const nextArguments =
|
||||
toolCallMessage.arguments &&
|
||||
Object.keys(toolCallMessage.arguments).length > 0
|
||||
? toolCallMessage.arguments
|
||||
: existing.arguments;
|
||||
nextMessages[existingIndex] = {
|
||||
...existing,
|
||||
toolName: toolCallMessage.toolName || existing.toolName,
|
||||
arguments: nextArguments,
|
||||
timestamp: toolCallMessage.timestamp,
|
||||
};
|
||||
return nextMessages;
|
||||
}
|
||||
|
||||
deps.setMessages(updateToolCallMessages);
|
||||
deps.setMessages((prev) => [...prev, toolCallMessage]);
|
||||
console.log("[Tool Call Start]", {
|
||||
toolId: toolCallMessage.toolId,
|
||||
toolName: toolCallMessage.toolName,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
}
|
||||
|
||||
export function handleToolResponse(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Tool Response] Received:", {
|
||||
toolId: chunk.tool_id,
|
||||
toolName: chunk.tool_name,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
let toolName = chunk.tool_name || "unknown";
|
||||
if (!chunk.tool_name || chunk.tool_name === "unknown") {
|
||||
deps.setMessages((prev) => {
|
||||
@@ -158,15 +127,22 @@ export function handleToolResponse(
|
||||
const toolCallIndex = prev.findIndex(
|
||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
const hasResponse = prev.some(
|
||||
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
if (hasResponse) return prev;
|
||||
if (toolCallIndex !== -1) {
|
||||
const newMessages = [...prev];
|
||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
||||
newMessages[toolCallIndex] = responseMessage;
|
||||
console.log(
|
||||
"[Tool Response] Replaced tool_call with matching tool_id:",
|
||||
chunk.tool_id,
|
||||
"at index:",
|
||||
toolCallIndex,
|
||||
);
|
||||
return newMessages;
|
||||
}
|
||||
console.warn(
|
||||
"[Tool Response] No tool_call found with tool_id:",
|
||||
chunk.tool_id,
|
||||
"appending instead",
|
||||
);
|
||||
return [...prev, responseMessage];
|
||||
});
|
||||
}
|
||||
@@ -191,38 +167,55 @@ export function handleStreamEnd(
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const completedContent = deps.streamingChunksRef.current.join("");
|
||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||
deps.setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "No response received. Please try again.",
|
||||
timestamp: new Date(),
|
||||
},
|
||||
]);
|
||||
}
|
||||
// Only save message if there are uncommitted chunks
|
||||
// (text_ended already saved if there were tool calls)
|
||||
if (completedContent.trim()) {
|
||||
console.log(
|
||||
"[Stream End] Saving remaining streamed text as assistant message",
|
||||
);
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedContent,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
deps.setMessages((prev) => {
|
||||
const updated = [...prev, assistantMessage];
|
||||
console.log("[Stream End] Final state:", {
|
||||
localMessages: updated.map((m) => ({
|
||||
type: m.type,
|
||||
...(m.type === "message" && {
|
||||
role: m.role,
|
||||
contentLength: m.content.length,
|
||||
}),
|
||||
...(m.type === "tool_call" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
}),
|
||||
...(m.type === "tool_response" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
success: m.success,
|
||||
}),
|
||||
})),
|
||||
streamingChunks: deps.streamingChunksRef.current,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
return updated;
|
||||
});
|
||||
} else {
|
||||
console.log("[Stream End] No uncommitted chunks, message already saved");
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setIsStreamingInitiated(false);
|
||||
console.log("[Stream End] Stream complete, messages in local state");
|
||||
}
|
||||
|
||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||
console.error("Stream error:", errorMessage);
|
||||
if (isRegionBlockedError(chunk)) {
|
||||
deps.setIsRegionBlockedModalOpen(true);
|
||||
}
|
||||
deps.setIsStreamingInitiated(false);
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setStreamingChunks([]);
|
||||
@@ -1,17 +1,14 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useCallback, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "../../useChatStream";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
import {
|
||||
createUserMessage,
|
||||
filterAuthMessages,
|
||||
hasSentInitialPrompt,
|
||||
isToolCallArray,
|
||||
isValidMessage,
|
||||
markInitialPromptSent,
|
||||
parseToolResponse,
|
||||
removePageContext,
|
||||
} from "./helpers";
|
||||
@@ -19,45 +16,20 @@ import {
|
||||
interface Args {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
}: Args) {
|
||||
export function useChatContainer({ sessionId, initialMessages }: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
const [hasTextChunks, setHasTextChunks] = useState(false);
|
||||
const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
|
||||
const [isRegionBlockedModalOpen, setIsRegionBlockedModalOpen] =
|
||||
useState(false);
|
||||
const hasResponseRef = useRef(false);
|
||||
const streamingChunksRef = useRef<string[]>([]);
|
||||
const previousSessionIdRef = useRef<string | null>(null);
|
||||
const {
|
||||
error,
|
||||
sendMessage: sendStreamMessage,
|
||||
stopStreaming,
|
||||
} = useChatStream();
|
||||
const { error, sendMessage: sendStreamMessage } = useChatStream();
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
|
||||
useEffect(() => {
|
||||
if (sessionId !== previousSessionIdRef.current) {
|
||||
stopStreaming(previousSessionIdRef.current ?? undefined, true);
|
||||
previousSessionIdRef.current = sessionId;
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
}
|
||||
}, [sessionId, stopStreaming]);
|
||||
|
||||
const allMessages = useMemo(() => {
|
||||
const processedInitialMessages: ChatMessageData[] = [];
|
||||
// Map to track tool calls by their ID so we can look up tool names for tool responses
|
||||
const toolCallMap = new Map<string, string>();
|
||||
|
||||
for (const msg of initialMessages) {
|
||||
@@ -73,9 +45,13 @@ export function useChatContainer({
|
||||
? new Date(msg.timestamp as string)
|
||||
: undefined;
|
||||
|
||||
// Remove page context from user messages when loading existing sessions
|
||||
if (role === "user") {
|
||||
content = removePageContext(content);
|
||||
if (!content.trim()) continue;
|
||||
// Skip user messages that become empty after removing page context
|
||||
if (!content.trim()) {
|
||||
continue;
|
||||
}
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "user",
|
||||
@@ -85,15 +61,19 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle assistant messages first (before tool messages) to build tool call map
|
||||
if (role === "assistant") {
|
||||
// Strip <thinking> tags from content
|
||||
content = content
|
||||
.replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
|
||||
.trim();
|
||||
|
||||
// If assistant has tool calls, create tool_call messages for each
|
||||
if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
|
||||
for (const toolCall of toolCalls) {
|
||||
const toolName = toolCall.function.name;
|
||||
const toolId = toolCall.id;
|
||||
// Store tool name for later lookup
|
||||
toolCallMap.set(toolId, toolName);
|
||||
|
||||
try {
|
||||
@@ -116,6 +96,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Only add assistant message if there's content after stripping thinking tags
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -125,6 +106,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
} else if (content.trim()) {
|
||||
// Assistant message without tool calls, but with content
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
@@ -135,6 +117,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle tool messages - look up tool name from tool call map
|
||||
if (role === "tool") {
|
||||
const toolCallId = (msg.tool_call_id as string) || "";
|
||||
const toolName = toolCallMap.get(toolCallId) || "unknown";
|
||||
@@ -150,6 +133,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle other message types (system, etc.)
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -170,10 +154,9 @@ export function useChatContainer({
|
||||
context?: { url: string; content: string },
|
||||
) {
|
||||
if (!sessionId) {
|
||||
console.error("[useChatContainer] Cannot send message: no session ID");
|
||||
console.error("Cannot send message: no session ID");
|
||||
return;
|
||||
}
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
if (isUserMessage) {
|
||||
const userMessage = createUserMessage(content);
|
||||
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
||||
@@ -184,19 +167,14 @@ export function useChatContainer({
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(true);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
});
|
||||
|
||||
try {
|
||||
await sendStreamMessage(
|
||||
sessionId,
|
||||
@@ -206,12 +184,8 @@ export function useChatContainer({
|
||||
context,
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("[useChatContainer] Failed to send message:", err);
|
||||
console.error("Failed to send message:", err);
|
||||
setIsStreamingInitiated(false);
|
||||
|
||||
// Don't show error toast for AbortError (expected during cleanup)
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to send message";
|
||||
toast.error("Failed to send message", {
|
||||
@@ -222,63 +196,11 @@ export function useChatContainer({
|
||||
[sessionId, sendStreamMessage],
|
||||
);
|
||||
|
||||
const handleStopStreaming = useCallback(() => {
|
||||
stopStreaming();
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
}, [stopStreaming]);
|
||||
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Send initial prompt if provided (for new sessions from homepage)
|
||||
useEffect(
|
||||
function handleInitialPrompt() {
|
||||
if (!initialPrompt || !sessionId) return;
|
||||
if (initialMessages.length > 0) return;
|
||||
if (hasSentInitialPrompt(sessionId)) return;
|
||||
|
||||
markInitialPromptSent(sessionId);
|
||||
const context = capturePageContext();
|
||||
sendMessage(initialPrompt, true, context);
|
||||
},
|
||||
[
|
||||
initialPrompt,
|
||||
sessionId,
|
||||
initialMessages.length,
|
||||
sendMessage,
|
||||
capturePageContext,
|
||||
],
|
||||
);
|
||||
|
||||
async function sendMessageWithContext(
|
||||
content: string,
|
||||
isUserMessage: boolean = true,
|
||||
) {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
}
|
||||
|
||||
function handleRegionModalOpenChange(open: boolean) {
|
||||
setIsRegionBlockedModalOpen(open);
|
||||
}
|
||||
|
||||
function handleRegionModalClose() {
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
}
|
||||
|
||||
return {
|
||||
messages: allMessages,
|
||||
streamingChunks,
|
||||
isStreaming,
|
||||
error,
|
||||
isRegionBlockedModalOpen,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sendMessageWithContext,
|
||||
handleRegionModalOpenChange,
|
||||
handleRegionModalClose,
|
||||
sendMessage,
|
||||
stopStreaming: handleStopStreaming,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ArrowUpIcon } from "@phosphor-icons/react";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
|
||||
export interface ChatInputProps {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
placeholder = "Type your message...",
|
||||
className,
|
||||
}: ChatInputProps) {
|
||||
const inputId = "chat-input";
|
||||
const { value, setValue, handleKeyDown, handleSend } = useChatInput({
|
||||
onSend,
|
||||
disabled,
|
||||
maxRows: 5,
|
||||
inputId,
|
||||
});
|
||||
|
||||
return (
|
||||
<div className={cn("relative flex-1", className)}>
|
||||
<Input
|
||||
id={inputId}
|
||||
label="Chat message input"
|
||||
hideLabel
|
||||
type="textarea"
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled}
|
||||
rows={1}
|
||||
wrapperClassName="mb-0 relative"
|
||||
className="pr-12"
|
||||
/>
|
||||
<span id="chat-input-hint" className="sr-only">
|
||||
Press Enter to send, Shift+Enter for new line
|
||||
</span>
|
||||
|
||||
<button
|
||||
onClick={handleSend}
|
||||
disabled={disabled || !value.trim()}
|
||||
className={cn(
|
||||
"absolute right-3 top-1/2 flex h-8 w-8 -translate-y-1/2 items-center justify-center rounded-full",
|
||||
"border border-zinc-800 bg-zinc-800 text-white",
|
||||
"hover:border-zinc-900 hover:bg-zinc-900",
|
||||
"disabled:border-zinc-200 disabled:bg-zinc-200 disabled:text-white disabled:opacity-50",
|
||||
"transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950",
|
||||
"disabled:pointer-events-none",
|
||||
)}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-3 w-3" weight="bold" />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
import { KeyboardEvent, useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface UseChatInputArgs {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
maxRows?: number;
|
||||
inputId?: string;
|
||||
}
|
||||
|
||||
export function useChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
maxRows = 5,
|
||||
inputId = "chat-input",
|
||||
}: UseChatInputArgs) {
|
||||
const [value, setValue] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (!textarea) return;
|
||||
textarea.style.height = "auto";
|
||||
const lineHeight = parseInt(
|
||||
window.getComputedStyle(textarea).lineHeight,
|
||||
10,
|
||||
);
|
||||
const maxHeight = lineHeight * maxRows;
|
||||
const newHeight = Math.min(textarea.scrollHeight, maxHeight);
|
||||
textarea.style.height = `${newHeight}px`;
|
||||
textarea.style.overflowY =
|
||||
textarea.scrollHeight > maxHeight ? "auto" : "hidden";
|
||||
}, [value, maxRows, inputId]);
|
||||
|
||||
const handleSend = useCallback(() => {
|
||||
if (disabled || !value.trim()) return;
|
||||
onSend(value.trim());
|
||||
setValue("");
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (textarea) {
|
||||
textarea.style.height = "auto";
|
||||
}
|
||||
}, [value, onSend, disabled, inputId]);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSend();
|
||||
}
|
||||
// Shift+Enter allows default behavior (new line) - no need to handle explicitly
|
||||
},
|
||||
[handleSend],
|
||||
);
|
||||
|
||||
return {
|
||||
value,
|
||||
setValue,
|
||||
handleKeyDown,
|
||||
handleSend,
|
||||
};
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user