mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Restrict LLM model and provider listings to enabled items
Updated public LLM model and provider listing endpoints to only return enabled models and providers. Refactored database access functions to support filtering by enabled status, and improved transaction safety for model deletion. Adjusted tests and internal documentation to reflect these changes.
This commit is contained in:
@@ -191,21 +191,20 @@ def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
|
||||
def get_llm_model_schema_options() -> list[dict[str, str]]:
|
||||
"""
|
||||
Get schema options for LLM model selection dropdown.
|
||||
Always rebuilds from current registry state to ensure enabled/disabled status is current.
|
||||
|
||||
Returns cached schema options that are refreshed when the registry is updated
|
||||
via refresh_llm_registry() (called on startup and via Redis pub/sub notifications).
|
||||
"""
|
||||
# Always rebuild to ensure we have the latest enabled/disabled status
|
||||
# This is called when generating block schemas, so we need fresh data
|
||||
_refresh_cached_schema()
|
||||
return _schema_options
|
||||
|
||||
|
||||
def get_llm_discriminator_mapping() -> dict[str, str]:
|
||||
"""
|
||||
Get discriminator mapping for LLM models.
|
||||
Always rebuilds from current registry state to ensure it's current.
|
||||
|
||||
Returns cached discriminator mapping that is refreshed when the registry is updated
|
||||
via refresh_llm_registry() (called on startup and via Redis pub/sub notifications).
|
||||
"""
|
||||
# Always rebuild to ensure we have the latest mapping
|
||||
_refresh_cached_schema()
|
||||
return _discriminator_mapping
|
||||
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ def test_list_llm_providers_success(
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.get_all_providers",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.list_providers",
|
||||
new=AsyncMock(return_value=mock_providers),
|
||||
)
|
||||
|
||||
@@ -102,7 +102,7 @@ def test_list_llm_models_success(
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.get_all_models",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.list_models",
|
||||
new=AsyncMock(return_value=mock_models),
|
||||
)
|
||||
|
||||
@@ -135,12 +135,12 @@ def test_create_llm_provider_success(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.upsert_provider",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.upsert_provider",
|
||||
new=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
|
||||
mock_notify = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
@@ -162,8 +162,8 @@ def test_create_llm_provider_success(
|
||||
assert response_data["name"] == "groq"
|
||||
assert response_data["display_name"] == "Groq"
|
||||
|
||||
# Verify notification was sent
|
||||
mock_notify.assert_called_once()
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(response_data, "create_llm_provider_success.json")
|
||||
@@ -196,12 +196,12 @@ def test_create_llm_model_success(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.create_model",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.create_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_notify = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
@@ -231,8 +231,8 @@ def test_create_llm_model_success(
|
||||
assert response_data["slug"] == "gpt-4.1-mini"
|
||||
assert response_data["is_enabled"] is True
|
||||
|
||||
# Verify notification was sent
|
||||
mock_notify.assert_called_once()
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(response_data, "create_llm_model_success.json")
|
||||
@@ -265,12 +265,12 @@ def test_update_llm_model_success(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.update_model",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.update_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_notify = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
@@ -288,8 +288,8 @@ def test_update_llm_model_success(
|
||||
assert response_data["display_name"] == "GPT-4o Updated"
|
||||
assert response_data["context_window"] == 256000
|
||||
|
||||
# Verify notification was sent
|
||||
mock_notify.assert_called_once()
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(response_data, "update_llm_model_success.json")
|
||||
@@ -315,12 +315,12 @@ def test_toggle_llm_model_success(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.toggle_model",
|
||||
"backend.server.v2.admin.llm_routes.llm_db.toggle_model",
|
||||
new=AsyncMock(return_value=mock_model),
|
||||
)
|
||||
|
||||
mock_notify = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
|
||||
mock_refresh = mocker.patch(
|
||||
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
|
||||
new=AsyncMock(),
|
||||
)
|
||||
|
||||
@@ -332,8 +332,8 @@ def test_toggle_llm_model_success(
|
||||
response_data = response.json()
|
||||
assert response_data["is_enabled"] is False
|
||||
|
||||
# Verify notification was sent
|
||||
mock_notify.assert_called_once()
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(response_data, "toggle_llm_model_success.json")
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Iterable, Sequence
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
|
||||
@@ -70,8 +71,21 @@ def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
|
||||
)
|
||||
|
||||
|
||||
async def list_providers(include_models: bool = True) -> list[llm_model.LlmProvider]:
|
||||
include = {"Models": {"include": {"Costs": True}}} if include_models else None
|
||||
async def list_providers(
|
||||
include_models: bool = True, enabled_only: bool = False
|
||||
) -> list[llm_model.LlmProvider]:
|
||||
"""
|
||||
List all LLM providers.
|
||||
|
||||
Args:
|
||||
include_models: Whether to include models for each provider
|
||||
enabled_only: If True, only include enabled models (for public routes)
|
||||
"""
|
||||
if include_models:
|
||||
model_where = {"isEnabled": True} if enabled_only else None
|
||||
include = {"Models": {"include": {"Costs": True}, "where": model_where}}
|
||||
else:
|
||||
include = None
|
||||
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
|
||||
return [_map_provider(record) for record in records]
|
||||
|
||||
@@ -107,10 +121,24 @@ async def upsert_provider(
|
||||
return _map_provider(record)
|
||||
|
||||
|
||||
async def list_models(provider_id: str | None = None) -> list[llm_model.LlmModel]:
|
||||
where = {"providerId": provider_id} if provider_id else None
|
||||
async def list_models(
|
||||
provider_id: str | None = None, enabled_only: bool = False
|
||||
) -> list[llm_model.LlmModel]:
|
||||
"""
|
||||
List LLM models.
|
||||
|
||||
Args:
|
||||
provider_id: Optional filter by provider ID
|
||||
enabled_only: If True, only return enabled models (for public routes)
|
||||
"""
|
||||
where: dict[str, Any] = {}
|
||||
if provider_id:
|
||||
where["providerId"] = provider_id
|
||||
if enabled_only:
|
||||
where["isEnabled"] = True
|
||||
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
where=where,
|
||||
where=where if where else None,
|
||||
include={"Costs": True},
|
||||
)
|
||||
return [_map_model(record) for record in records]
|
||||
@@ -227,12 +255,12 @@ async def delete_model(
|
||||
"""
|
||||
Delete a model and migrate all AgentNodes using it to a replacement model.
|
||||
|
||||
This performs an atomic operation:
|
||||
This performs an atomic operation within a database transaction:
|
||||
1. Validates the model exists
|
||||
2. Validates the replacement model exists and is enabled
|
||||
3. Counts affected nodes
|
||||
4. Migrates all AgentNode.constantInput->model to replacement
|
||||
5. Deletes the LlmModel record (CASCADE deletes costs)
|
||||
4. Migrates all AgentNode.constantInput->model to replacement (in transaction)
|
||||
5. Deletes the LlmModel record (CASCADE deletes costs) (in transaction)
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model to delete
|
||||
@@ -244,9 +272,7 @@ async def delete_model(
|
||||
Raises:
|
||||
ValueError: If model not found, replacement not found, or replacement is disabled
|
||||
"""
|
||||
import prisma as prisma_module
|
||||
|
||||
# 1. Get the model being deleted
|
||||
# 1. Get the model being deleted (validation - outside transaction)
|
||||
model = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id}, include={"Costs": True}
|
||||
)
|
||||
@@ -256,7 +282,7 @@ async def delete_model(
|
||||
deleted_slug = model.slug
|
||||
deleted_display_name = model.displayName
|
||||
|
||||
# 2. Validate replacement model exists and is enabled
|
||||
# 2. Validate replacement model exists and is enabled (validation - outside transaction)
|
||||
replacement = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"slug": replacement_model_slug}
|
||||
)
|
||||
@@ -268,7 +294,9 @@ async def delete_model(
|
||||
f"Please enable it before using it as a replacement."
|
||||
)
|
||||
|
||||
# 3. Count affected nodes
|
||||
# 3. Count affected nodes (read - outside transaction)
|
||||
import prisma as prisma_module
|
||||
|
||||
count_result = await prisma_module.get_client().query_raw(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
@@ -279,24 +307,26 @@ async def delete_model(
|
||||
)
|
||||
nodes_affected = int(count_result[0]["count"]) if count_result else 0
|
||||
|
||||
# 4. Perform migration
|
||||
if nodes_affected > 0:
|
||||
await prisma_module.get_client().execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
# 4 & 5. Perform migration and deletion atomically within a transaction
|
||||
async with transaction() as tx:
|
||||
# Migrate all AgentNode.constantInput->model to replacement
|
||||
if nodes_affected > 0:
|
||||
await tx.execute_raw(
|
||||
"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
to_jsonb($1::text)
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = $2
|
||||
""",
|
||||
replacement_model_slug,
|
||||
deleted_slug,
|
||||
)
|
||||
|
||||
# 5. Delete the model (CASCADE will delete costs automatically)
|
||||
await prisma.models.LlmModel.prisma().delete(where={"id": model_id})
|
||||
# Delete the model (CASCADE will delete costs automatically)
|
||||
await tx.llmmodel.delete(where={"id": model_id})
|
||||
|
||||
return llm_model.DeleteLlmModelResponse(
|
||||
deleted_model_slug=deleted_slug,
|
||||
|
||||
@@ -13,11 +13,13 @@ router = fastapi.APIRouter(
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
|
||||
async def list_models():
|
||||
models = await llm_db.list_models()
|
||||
"""List all enabled LLM models available to users."""
|
||||
models = await llm_db.list_models(enabled_only=True)
|
||||
return llm_model.LlmModelsResponse(models=models)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
|
||||
async def list_providers():
|
||||
providers = await llm_db.list_providers(include_models=True)
|
||||
"""List all LLM providers with their enabled models."""
|
||||
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
|
||||
Reference in New Issue
Block a user