diff --git a/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py b/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py index 8913c00077..1df468492a 100644 --- a/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py +++ b/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py @@ -302,16 +302,26 @@ async def toggle_model( status_code=404, detail=f"Model with slug '{slug}' not found" ) + is_enabling = request.get("is_enabled", True) + if not is_enabling and existing.isRecommended: + raise HTTPException( + status_code=400, + detail=( + "Cannot disable the recommended model. " + "Change the recommended model before disabling this one." + ), + ) + result = await db_write.toggle_model_with_migration( model_id=existing.id, - is_enabled=request.get("is_enabled", True), + is_enabled=is_enabling, migrate_to_slug=request.get("migrate_to_slug"), migration_reason=request.get("migration_reason"), custom_credit_cost=request.get("custom_credit_cost"), ) await db_write.refresh_runtime_caches() logger.info( - f"Toggled model '{slug}' enabled={request.get('is_enabled')} " + f"Toggled model '{slug}' enabled={is_enabling} " f"(migrated {result['nodes_migrated']} nodes)" ) return result diff --git a/autogpt_platform/backend/backend/server/v2/llm/db_write.py b/autogpt_platform/backend/backend/server/v2/llm/db_write.py index 87b0f9ab4a..1d4604b7e1 100644 --- a/autogpt_platform/backend/backend/server/v2/llm/db_write.py +++ b/autogpt_platform/backend/backend/server/v2/llm/db_write.py @@ -14,6 +14,15 @@ from backend.data.db import transaction logger = logging.getLogger(__name__) +def _node_model_value(slug: str) -> str: + """Extract the model value stored in AgentNode.constantInput from a registry slug. + + Registry slugs are formatted as 'provider/model-name' (e.g. 'openai/gpt-4o'). + The LLM block stores only the model-name part (e.g. 'gpt-4o') in constantInput. + """ + return slug.split("/", 1)[-1] if "/" in slug else slug + + def _build_provider_data( name: str, display_name: str, @@ -293,13 +302,14 @@ async def get_model_usage(slug: str) -> dict[str, Any]: """Get usage count for a model — how many AgentNodes reference it.""" import prisma as prisma_module + model_value = _node_model_value(slug) count_result = await prisma_module.get_client().query_raw( """ SELECT COUNT(*) as count FROM "AgentNode" WHERE "constantInput"::jsonb->>'model' = $1 """, - slug, + model_value, ) node_count = int(count_result[0]["count"]) if count_result else 0 return {"model_slug": slug, "node_count": node_count} @@ -337,6 +347,8 @@ async def toggle_model_with_migration( f"Please enable it before using it as a replacement." ) + source_value = _node_model_value(model.slug) + target_value = _node_model_value(migrate_to_slug) node_ids_result = await tx.query_raw( """ SELECT id @@ -344,7 +356,7 @@ async def toggle_model_with_migration( WHERE "constantInput"::jsonb->>'model' = $1 FOR UPDATE """, - model.slug, + source_value, ) migrated_node_ids = ( [row["id"] for row in node_ids_result] if node_ids_result else [] @@ -365,7 +377,7 @@ async def toggle_model_with_migration( SELECT jsonb_array_elements_text($2::jsonb) ) """, - migrate_to_slug, + target_value, node_ids_json, )