fix(backend/llm-admin-api): fix usage count, slug mismatch, and recommended model guard

db_write.py:
- Add _node_model_value() helper: strips provider prefix from slug so
  queries match the enum value stored in AgentNode.constantInput
  (e.g. 'openai/gpt-4o' -> 'gpt-4o')
- get_model_usage: use _node_model_value(slug) — was always returning 0
- toggle_model_with_migration: use _node_model_value for both the SELECT
  (finding nodes) and the UPDATE (writing the replacement value)

admin_routes.py:
- toggle_model: guard against disabling the recommended model with 400
  'Cannot disable the recommended model. Change it first.'
This commit is contained in:
Bentlybro
2026-04-08 15:54:54 +01:00
parent 5352dc9778
commit da533089d3
2 changed files with 27 additions and 5 deletions

View File

@@ -302,16 +302,26 @@ async def toggle_model(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
is_enabling = request.get("is_enabled", True)
if not is_enabling and existing.isRecommended:
raise HTTPException(
status_code=400,
detail=(
"Cannot disable the recommended model. "
"Change the recommended model before disabling this one."
),
)
result = await db_write.toggle_model_with_migration(
model_id=existing.id,
is_enabled=request.get("is_enabled", True),
is_enabled=is_enabling,
migrate_to_slug=request.get("migrate_to_slug"),
migration_reason=request.get("migration_reason"),
custom_credit_cost=request.get("custom_credit_cost"),
)
await db_write.refresh_runtime_caches()
logger.info(
f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
f"Toggled model '{slug}' enabled={is_enabling} "
f"(migrated {result['nodes_migrated']} nodes)"
)
return result

View File

@@ -14,6 +14,15 @@ from backend.data.db import transaction
logger = logging.getLogger(__name__)
def _node_model_value(slug: str) -> str:
"""Extract the model value stored in AgentNode.constantInput from a registry slug.
Registry slugs are formatted as 'provider/model-name' (e.g. 'openai/gpt-4o').
The LLM block stores only the model-name part (e.g. 'gpt-4o') in constantInput.
"""
return slug.split("/", 1)[-1] if "/" in slug else slug
def _build_provider_data(
name: str,
display_name: str,
@@ -293,13 +302,14 @@ async def get_model_usage(slug: str) -> dict[str, Any]:
"""Get usage count for a model — how many AgentNodes reference it."""
import prisma as prisma_module
model_value = _node_model_value(slug)
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
slug,
model_value,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return {"model_slug": slug, "node_count": node_count}
@@ -337,6 +347,8 @@ async def toggle_model_with_migration(
f"Please enable it before using it as a replacement."
)
source_value = _node_model_value(model.slug)
target_value = _node_model_value(migrate_to_slug)
node_ids_result = await tx.query_raw(
"""
SELECT id
@@ -344,7 +356,7 @@ async def toggle_model_with_migration(
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
source_value,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
@@ -365,7 +377,7 @@ async def toggle_model_with_migration(
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
target_value,
node_ids_json,
)