diff --git a/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py b/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py index 8f75ff839e..3ad887c89e 100644 --- a/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py +++ b/autogpt_platform/backend/backend/server/v2/llm/admin_routes.py @@ -216,18 +216,19 @@ async def update_model( @router.delete( "/llm/models/{slug:path}", - status_code=status.HTTP_204_NO_CONTENT, dependencies=[Security(autogpt_libs.auth.requires_admin_user)], ) async def delete_model( slug: str, -) -> None: - """Delete an LLM model. + replacement_model_slug: str | None = None, +) -> dict[str, Any]: + """Delete an LLM model with optional migration. - Requires admin authentication. + If workflows are using this model and no replacement_model_slug is given, + returns 400 with the node count. Provide replacement_model_slug to migrate + affected nodes before deletion. """ try: - # Find model by slug first to get ID import prisma.models existing = await prisma.models.LlmModel.prisma().find_unique( @@ -238,9 +239,15 @@ async def delete_model( status_code=404, detail=f"Model with slug '{slug}' not found" ) - await db_write.delete_model(model_id=existing.id) + result = await db_write.delete_model( + model_id=existing.id, + replacement_model_slug=replacement_model_slug, + ) await db_write.refresh_runtime_caches() - logger.info(f"Deleted model '{slug}' (id: {existing.id})") + logger.info( + f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)" + ) + return result except ValueError as e: logger.warning(f"Model deletion validation failed: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -249,6 +256,117 @@ async def delete_model( raise HTTPException(status_code=500, detail="Failed to delete model") +@router.get( + "/llm/models/{slug:path}/usage", + dependencies=[Security(autogpt_libs.auth.requires_admin_user)], +) +async def get_model_usage(slug: str) -> dict[str, Any]: + """Get usage count for a model — how many workflow nodes reference it.""" + try: + return await db_write.get_model_usage(slug) + except Exception as e: + logger.exception(f"Failed to get model usage: {e}") + raise HTTPException(status_code=500, detail="Failed to get model usage") + + +@router.post( + "/llm/models/{slug:path}/toggle", + dependencies=[Security(autogpt_libs.auth.requires_admin_user)], +) +async def toggle_model( + slug: str, + request: dict[str, Any], +) -> dict[str, Any]: + """Toggle a model's enabled status with optional migration when disabling. + + Body params: + is_enabled: bool + migrate_to_slug: optional str + migration_reason: optional str + custom_credit_cost: optional int + """ + try: + import prisma.models + + existing = await prisma.models.LlmModel.prisma().find_unique( + where={"slug": slug} + ) + if not existing: + raise HTTPException( + status_code=404, detail=f"Model with slug '{slug}' not found" + ) + + result = await db_write.toggle_model_with_migration( + model_id=existing.id, + is_enabled=request.get("is_enabled", True), + migrate_to_slug=request.get("migrate_to_slug"), + migration_reason=request.get("migration_reason"), + custom_credit_cost=request.get("custom_credit_cost"), + ) + await db_write.refresh_runtime_caches() + logger.info( + f"Toggled model '{slug}' enabled={request.get('is_enabled')} " + f"(migrated {result['nodes_migrated']} nodes)" + ) + return result + except ValueError as e: + logger.warning(f"Model toggle failed: {e}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception(f"Failed to toggle model: {e}") + raise HTTPException(status_code=500, detail="Failed to toggle model") + + +@router.get( + "/llm/migrations", + dependencies=[Security(autogpt_libs.auth.requires_admin_user)], +) +async def list_migrations( + include_reverted: bool = False, +) -> dict[str, Any]: + """List model migrations.""" + try: + migrations = await db_write.list_migrations( + include_reverted=include_reverted + ) + return {"migrations": migrations} + except Exception as e: + logger.exception(f"Failed to list migrations: {e}") + raise HTTPException( + status_code=500, detail="Failed to list migrations" + ) + + +@router.post( + "/llm/migrations/{migration_id}/revert", + dependencies=[Security(autogpt_libs.auth.requires_admin_user)], +) +async def revert_migration( + migration_id: str, + re_enable_source_model: bool = True, +) -> dict[str, Any]: + """Revert a model migration, restoring affected nodes.""" + try: + result = await db_write.revert_migration( + migration_id=migration_id, + re_enable_source_model=re_enable_source_model, + ) + await db_write.refresh_runtime_caches() + logger.info( + f"Reverted migration {migration_id}: " + f"{result['nodes_reverted']} nodes restored" + ) + return result + except ValueError as e: + logger.warning(f"Migration revert failed: {e}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception(f"Failed to revert migration: {e}") + raise HTTPException( + status_code=500, detail="Failed to revert migration" + ) + + @router.post( "/llm/providers", status_code=status.HTTP_201_CREATED, diff --git a/autogpt_platform/backend/backend/server/v2/llm/db_write.py b/autogpt_platform/backend/backend/server/v2/llm/db_write.py index a2e082c0f2..d30b685d5a 100644 --- a/autogpt_platform/backend/backend/server/v2/llm/db_write.py +++ b/autogpt_platform/backend/backend/server/v2/llm/db_write.py @@ -1,11 +1,17 @@ """Database write operations for LLM registry admin API.""" +import json +import logging +from datetime import datetime, timezone from typing import Any import prisma import prisma.models from backend.data import llm_registry +from backend.data.db import transaction + +logger = logging.getLogger(__name__) def _build_provider_data( @@ -270,25 +276,291 @@ async def update_model( return model -async def delete_model(model_id: str) -> bool: - """Delete an LLM model. +async def get_model_usage(slug: str) -> dict[str, Any]: + """Get usage count for a model — how many AgentNodes reference it.""" + import prisma as prisma_module - Note: This should check if any workflows are using this model first. - For now, we'll allow deletion and rely on FK constraints. - """ - # Check if model exists - model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id}) + count_result = await prisma_module.get_client().query_raw( + """ + SELECT COUNT(*) as count + FROM "AgentNode" + WHERE "constantInput"::jsonb->>'model' = $1 + """, + slug, + ) + node_count = int(count_result[0]["count"]) if count_result else 0 + return {"model_slug": slug, "node_count": node_count} + + +async def toggle_model_with_migration( + model_id: str, + is_enabled: bool, + migrate_to_slug: str | None = None, + migration_reason: str | None = None, + custom_credit_cost: int | None = None, +) -> dict[str, Any]: + """Toggle a model's enabled status, optionally migrating workflows when disabling.""" + model = await prisma.models.LlmModel.prisma().find_unique( + where={"id": model_id}, include={"Costs": True} + ) if not model: raise ValueError(f"Model with id '{model_id}' not found") - await prisma.models.LlmModel.prisma().delete(where={"id": model_id}) - return True + nodes_migrated = 0 + migration_id: str | None = None + + if not is_enabled and migrate_to_slug: + async with transaction() as tx: + replacement = await tx.llmmodel.find_unique( + where={"slug": migrate_to_slug} + ) + if not replacement: + raise ValueError( + f"Replacement model '{migrate_to_slug}' not found" + ) + if not replacement.isEnabled: + raise ValueError( + f"Replacement model '{migrate_to_slug}' is disabled. " + f"Please enable it before using it as a replacement." + ) + + node_ids_result = await tx.query_raw( + """ + SELECT id + FROM "AgentNode" + WHERE "constantInput"::jsonb->>'model' = $1 + FOR UPDATE + """, + model.slug, + ) + migrated_node_ids = ( + [row["id"] for row in node_ids_result] if node_ids_result else [] + ) + nodes_migrated = len(migrated_node_ids) + + if nodes_migrated > 0: + node_ids_json = json.dumps(migrated_node_ids) + await tx.execute_raw( + """ + UPDATE "AgentNode" + SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{model}', + to_jsonb($1::text) + ) + WHERE id::text IN ( + SELECT jsonb_array_elements_text($2::jsonb) + ) + """, + migrate_to_slug, + node_ids_json, + ) + + await tx.llmmodel.update( + where={"id": model_id}, + data={"isEnabled": is_enabled}, + ) + + if nodes_migrated > 0: + migration_record = await tx.llmmodelmigration.create( + data={ + "sourceModelSlug": model.slug, + "targetModelSlug": migrate_to_slug, + "reason": migration_reason, + "migratedNodeIds": json.dumps(migrated_node_ids), + "nodeCount": nodes_migrated, + "customCreditCost": custom_credit_cost, + } + ) + migration_id = migration_record.id + else: + await prisma.models.LlmModel.prisma().update( + where={"id": model_id}, + data={"isEnabled": is_enabled}, + ) + + return { + "nodes_migrated": nodes_migrated, + "migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None, + "migration_id": migration_id, + } + + +async def delete_model( + model_id: str, replacement_model_slug: str | None = None +) -> dict[str, Any]: + """Delete an LLM model, optionally migrating affected AgentNodes first. + + If workflows are using this model and no replacement is given, raises ValueError. + If replacement is given, atomically migrates all affected nodes then deletes. + """ + model = await prisma.models.LlmModel.prisma().find_unique( + where={"id": model_id}, include={"Costs": True} + ) + if not model: + raise ValueError(f"Model with id '{model_id}' not found") + + deleted_slug = model.slug + deleted_display_name = model.displayName + + async with transaction() as tx: + count_result = await tx.query_raw( + """ + SELECT COUNT(*) as count + FROM "AgentNode" + WHERE "constantInput"::jsonb->>'model' = $1 + """, + deleted_slug, + ) + nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0 + + if nodes_to_migrate > 0: + if not replacement_model_slug: + raise ValueError( + f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) " + f"are using it. Please provide a replacement_model_slug to migrate them." + ) + replacement = await tx.llmmodel.find_unique( + where={"slug": replacement_model_slug} + ) + if not replacement: + raise ValueError( + f"Replacement model '{replacement_model_slug}' not found" + ) + if not replacement.isEnabled: + raise ValueError( + f"Replacement model '{replacement_model_slug}' is disabled." + ) + + await tx.execute_raw( + """ + UPDATE "AgentNode" + SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{model}', + to_jsonb($1::text) + ) + WHERE "constantInput"::jsonb->>'model' = $2 + """, + replacement_model_slug, + deleted_slug, + ) + + await tx.llmmodel.delete(where={"id": model_id}) + + return { + "deleted_model_slug": deleted_slug, + "deleted_model_display_name": deleted_display_name, + "replacement_model_slug": replacement_model_slug, + "nodes_migrated": nodes_to_migrate, + } + + +async def list_migrations( + include_reverted: bool = False, +) -> list[dict[str, Any]]: + """List model migrations.""" + where: Any = None if include_reverted else {"isReverted": False} + records = await prisma.models.LlmModelMigration.prisma().find_many( + where=where, + order={"createdAt": "desc"}, + ) + return [ + { + "id": r.id, + "source_model_slug": r.sourceModelSlug, + "target_model_slug": r.targetModelSlug, + "reason": r.reason, + "node_count": r.nodeCount, + "custom_credit_cost": r.customCreditCost, + "is_reverted": r.isReverted, + "reverted_at": r.revertedAt.isoformat() if r.revertedAt else None, + "created_at": r.createdAt.isoformat(), + } + for r in records + ] + + +async def revert_migration( + migration_id: str, + re_enable_source_model: bool = True, +) -> dict[str, Any]: + """Revert a model migration, restoring affected nodes to their original model.""" + migration = await prisma.models.LlmModelMigration.prisma().find_unique( + where={"id": migration_id} + ) + if not migration: + raise ValueError(f"Migration with id '{migration_id}' not found") + + if migration.isReverted: + raise ValueError( + f"Migration '{migration_id}' has already been reverted" + ) + + source_model = await prisma.models.LlmModel.prisma().find_unique( + where={"slug": migration.sourceModelSlug} + ) + if not source_model: + raise ValueError( + f"Source model '{migration.sourceModelSlug}' no longer exists." + ) + + migrated_node_ids: list[str] = ( + migration.migratedNodeIds + if isinstance(migration.migratedNodeIds, list) + else json.loads(migration.migratedNodeIds) # type: ignore + ) + if not migrated_node_ids: + raise ValueError("No nodes to revert in this migration") + + source_model_re_enabled = False + + async with transaction() as tx: + if not source_model.isEnabled and re_enable_source_model: + await tx.llmmodel.update( + where={"id": source_model.id}, + data={"isEnabled": True}, + ) + source_model_re_enabled = True + + node_ids_json = json.dumps(migrated_node_ids) + result = await tx.execute_raw( + """ + UPDATE "AgentNode" + SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{model}', + to_jsonb($1::text) + ) + WHERE id::text IN ( + SELECT jsonb_array_elements_text($2::jsonb) + ) + AND "constantInput"::jsonb->>'model' = $3 + """, + migration.sourceModelSlug, + node_ids_json, + migration.targetModelSlug, + ) + nodes_reverted = result if isinstance(result, int) else 0 + + await tx.llmmodelmigration.update( + where={"id": migration_id}, + data={ + "isReverted": True, + "revertedAt": datetime.now(timezone.utc), + }, + ) + + return { + "migration_id": migration_id, + "source_model_slug": migration.sourceModelSlug, + "target_model_slug": migration.targetModelSlug, + "nodes_reverted": nodes_reverted, + "nodes_already_changed": len(migrated_node_ids) - nodes_reverted, + "source_model_re_enabled": source_model_re_enabled, + } async def refresh_runtime_caches() -> None: """Refresh the LLM registry and clear all related caches.""" - # Refresh the in-memory registry await llm_registry.refresh_llm_registry() - - # TODO: Clear block schema caches when block integration is implemented - # TODO: Publish registry refresh notification to executors