feat(backend): Add model migration system - usage tracking, safe delete, disable with migration, revert

- GET /llm/models/{slug}/usage - count AgentNodes using a model
- DELETE /llm/models/{slug} with optional replacement_model_slug for safe migration
- POST /llm/models/{slug}/toggle with migration support when disabling
- GET /llm/migrations - list model migrations (with include_reverted filter)
- POST /llm/migrations/{id}/revert - revert a migration (restores nodes, re-enables source model)
- Transactional migration: counts nodes, migrates atomically, creates LlmModelMigration audit record
- Ported from original PR #11699's db.py
This commit is contained in:
Bentlybro
2026-04-04 19:36:17 +00:00
parent 445eb173a5
commit 36045c7007
2 changed files with 410 additions and 20 deletions

View File

@@ -216,18 +216,19 @@ async def update_model(
@router.delete(
"/llm/models/{slug:path}",
status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_model(
slug: str,
) -> None:
"""Delete an LLM model.
replacement_model_slug: str | None = None,
) -> dict[str, Any]:
"""Delete an LLM model with optional migration.
Requires admin authentication.
If workflows are using this model and no replacement_model_slug is given,
returns 400 with the node count. Provide replacement_model_slug to migrate
affected nodes before deletion.
"""
try:
# Find model by slug first to get ID
import prisma.models
existing = await prisma.models.LlmModel.prisma().find_unique(
@@ -238,9 +239,15 @@ async def delete_model(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
await db_write.delete_model(model_id=existing.id)
result = await db_write.delete_model(
model_id=existing.id,
replacement_model_slug=replacement_model_slug,
)
await db_write.refresh_runtime_caches()
logger.info(f"Deleted model '{slug}' (id: {existing.id})")
logger.info(
f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
)
return result
except ValueError as e:
logger.warning(f"Model deletion validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
@@ -249,6 +256,117 @@ async def delete_model(
raise HTTPException(status_code=500, detail="Failed to delete model")
@router.get(
"/llm/models/{slug:path}/usage",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def get_model_usage(slug: str) -> dict[str, Any]:
"""Get usage count for a model — how many workflow nodes reference it."""
try:
return await db_write.get_model_usage(slug)
except Exception as e:
logger.exception(f"Failed to get model usage: {e}")
raise HTTPException(status_code=500, detail="Failed to get model usage")
@router.post(
"/llm/models/{slug:path}/toggle",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def toggle_model(
slug: str,
request: dict[str, Any],
) -> dict[str, Any]:
"""Toggle a model's enabled status with optional migration when disabling.
Body params:
is_enabled: bool
migrate_to_slug: optional str
migration_reason: optional str
custom_credit_cost: optional int
"""
try:
import prisma.models
existing = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": slug}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
result = await db_write.toggle_model_with_migration(
model_id=existing.id,
is_enabled=request.get("is_enabled", True),
migrate_to_slug=request.get("migrate_to_slug"),
migration_reason=request.get("migration_reason"),
custom_credit_cost=request.get("custom_credit_cost"),
)
await db_write.refresh_runtime_caches()
logger.info(
f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
f"(migrated {result['nodes_migrated']} nodes)"
)
return result
except ValueError as e:
logger.warning(f"Model toggle failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to toggle model: {e}")
raise HTTPException(status_code=500, detail="Failed to toggle model")
@router.get(
"/llm/migrations",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_migrations(
include_reverted: bool = False,
) -> dict[str, Any]:
"""List model migrations."""
try:
migrations = await db_write.list_migrations(
include_reverted=include_reverted
)
return {"migrations": migrations}
except Exception as e:
logger.exception(f"Failed to list migrations: {e}")
raise HTTPException(
status_code=500, detail="Failed to list migrations"
)
@router.post(
"/llm/migrations/{migration_id}/revert",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> dict[str, Any]:
"""Revert a model migration, restoring affected nodes."""
try:
result = await db_write.revert_migration(
migration_id=migration_id,
re_enable_source_model=re_enable_source_model,
)
await db_write.refresh_runtime_caches()
logger.info(
f"Reverted migration {migration_id}: "
f"{result['nodes_reverted']} nodes restored"
)
return result
except ValueError as e:
logger.warning(f"Migration revert failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to revert migration: {e}")
raise HTTPException(
status_code=500, detail="Failed to revert migration"
)
@router.post(
"/llm/providers",
status_code=status.HTTP_201_CREATED,

View File

@@ -1,11 +1,17 @@
"""Database write operations for LLM registry admin API."""
import json
import logging
from datetime import datetime, timezone
from typing import Any
import prisma
import prisma.models
from backend.data import llm_registry
from backend.data.db import transaction
logger = logging.getLogger(__name__)
def _build_provider_data(
@@ -270,25 +276,291 @@ async def update_model(
return model
async def delete_model(model_id: str) -> bool:
"""Delete an LLM model.
async def get_model_usage(slug: str) -> dict[str, Any]:
"""Get usage count for a model — how many AgentNodes reference it."""
import prisma as prisma_module
Note: This should check if any workflows are using this model first.
For now, we'll allow deletion and rely on FK constraints.
"""
# Check if model exists
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return {"model_slug": slug, "node_count": node_count}
async def toggle_model_with_migration(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> dict[str, Any]:
"""Toggle a model's enabled status, optionally migrating workflows when disabling."""
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
await prisma.models.LlmModel.prisma().delete(where={"id": model_id})
return True
nodes_migrated = 0
migration_id: str | None = None
if not is_enabled and migrate_to_slug:
async with transaction() as tx:
replacement = await tx.llmmodel.find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(
f"Replacement model '{migrate_to_slug}' not found"
)
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
node_ids_json = json.dumps(migrated_node_ids)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
node_ids_json,
)
await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
)
if nodes_migrated > 0:
migration_record = await tx.llmmodelmigration.create(
data={
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
)
migration_id = migration_record.id
else:
await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
)
return {
"nodes_migrated": nodes_migrated,
"migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
"migration_id": migration_id,
}
async def delete_model(
model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
"""Delete an LLM model, optionally migrating affected AgentNodes first.
If workflows are using this model and no replacement is given, raises ValueError.
If replacement is given, atomically migrates all affected nodes then deletes.
"""
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
async with transaction() as tx:
count_result = await tx.query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
if nodes_to_migrate > 0:
if not replacement_model_slug:
raise ValueError(
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
f"are using it. Please provide a replacement_model_slug to migrate them."
)
replacement = await tx.llmmodel.find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(
f"Replacement model '{replacement_model_slug}' not found"
)
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled."
)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
await tx.llmmodel.delete(where={"id": model_id})
return {
"deleted_model_slug": deleted_slug,
"deleted_model_display_name": deleted_display_name,
"replacement_model_slug": replacement_model_slug,
"nodes_migrated": nodes_to_migrate,
}
async def list_migrations(
include_reverted: bool = False,
) -> list[dict[str, Any]]:
"""List model migrations."""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [
{
"id": r.id,
"source_model_slug": r.sourceModelSlug,
"target_model_slug": r.targetModelSlug,
"reason": r.reason,
"node_count": r.nodeCount,
"custom_credit_cost": r.customCreditCost,
"is_reverted": r.isReverted,
"reverted_at": r.revertedAt.isoformat() if r.revertedAt else None,
"created_at": r.createdAt.isoformat(),
}
for r in records
]
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> dict[str, Any]:
"""Revert a model migration, restoring affected nodes to their original model."""
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted"
)
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists."
)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
source_model_re_enabled = False
async with transaction() as tx:
if not source_model.isEnabled and re_enable_source_model:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
node_ids_json = json.dumps(migrated_node_ids)
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_json,
migration.targetModelSlug,
)
nodes_reverted = result if isinstance(result, int) else 0
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
return {
"migration_id": migration_id,
"source_model_slug": migration.sourceModelSlug,
"target_model_slug": migration.targetModelSlug,
"nodes_reverted": nodes_reverted,
"nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
"source_model_re_enabled": source_model_re_enabled,
}
async def refresh_runtime_caches() -> None:
"""Refresh the LLM registry and clear all related caches."""
# Refresh the in-memory registry
await llm_registry.refresh_llm_registry()
# TODO: Clear block schema caches when block integration is implemented
# TODO: Publish registry refresh notification to executors