Restrict LLM model and provider listings to enabled items

Updated the public LLM model and provider listing endpoints so that they return only enabled models; providers are still all listed, each with only its enabled models. Refactored the database access functions to support filtering by enabled status, and improved transaction safety for model deletion. Adjusted tests and internal documentation to reflect these changes.
This commit is contained in:
Bentlybro
2025-12-04 15:56:25 +00:00
parent ec705bbbcf
commit a97fdba554
4 changed files with 91 additions and 60 deletions

View File

@@ -191,21 +191,20 @@ def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
def get_llm_model_schema_options() -> list[dict[str, str]]:
    """
    Get schema options for the LLM model selection dropdown.

    Returns the cached schema options. The cache is refreshed when the registry
    is updated via refresh_llm_registry() (called on startup and via Redis
    pub/sub notifications), so it is not rebuilt on every call.
    """
    return _schema_options
def get_llm_discriminator_mapping() -> dict[str, str]:
    """
    Get the discriminator mapping for LLM models.

    Returns the cached discriminator mapping. The cache is refreshed when the
    registry is updated via refresh_llm_registry() (called on startup and via
    Redis pub/sub notifications), so it is not rebuilt on every call.
    """
    return _discriminator_mapping

View File

@@ -57,7 +57,7 @@ def test_list_llm_providers_success(
]
mocker.patch(
"backend.server.v2.admin.llm_routes.get_all_providers",
"backend.server.v2.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
@@ -102,7 +102,7 @@ def test_list_llm_models_success(
]
mocker.patch(
"backend.server.v2.admin.llm_routes.get_all_models",
"backend.server.v2.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_models),
)
@@ -135,12 +135,12 @@ def test_create_llm_provider_success(
}
mocker.patch(
"backend.server.v2.admin.llm_routes.upsert_provider",
"backend.server.v2.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_notify = mocker.patch(
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
mock_refresh = mocker.patch(
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
@@ -162,8 +162,8 @@ def test_create_llm_provider_success(
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify notification was sent
mock_notify.assert_called_once()
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "create_llm_provider_success.json")
@@ -196,12 +196,12 @@ def test_create_llm_model_success(
}
mocker.patch(
"backend.server.v2.admin.llm_routes.create_model",
"backend.server.v2.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_notify = mocker.patch(
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
mock_refresh = mocker.patch(
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
@@ -231,8 +231,8 @@ def test_create_llm_model_success(
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify notification was sent
mock_notify.assert_called_once()
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "create_llm_model_success.json")
@@ -265,12 +265,12 @@ def test_update_llm_model_success(
}
mocker.patch(
"backend.server.v2.admin.llm_routes.update_model",
"backend.server.v2.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_notify = mocker.patch(
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
mock_refresh = mocker.patch(
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
@@ -288,8 +288,8 @@ def test_update_llm_model_success(
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify notification was sent
mock_notify.assert_called_once()
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "update_llm_model_success.json")
@@ -315,12 +315,12 @@ def test_toggle_llm_model_success(
}
mocker.patch(
"backend.server.v2.admin.llm_routes.toggle_model",
"backend.server.v2.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_model),
)
mock_notify = mocker.patch(
"backend.server.v2.admin.llm_routes.notify_llm_registry_refresh",
mock_refresh = mocker.patch(
"backend.server.v2.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
@@ -332,8 +332,8 @@ def test_toggle_llm_model_success(
response_data = response.json()
assert response_data["is_enabled"] is False
# Verify notification was sent
mock_notify.assert_called_once()
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response
configured_snapshot.assert_match(response_data, "toggle_llm_model_success.json")

View File

@@ -4,6 +4,7 @@ from typing import Any, Iterable, Sequence
import prisma.models
from backend.data.db import transaction
from backend.server.v2.llm import model as llm_model
@@ -70,8 +71,21 @@ def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
)
async def list_providers(
    include_models: bool = True, enabled_only: bool = False
) -> list[llm_model.LlmProvider]:
    """
    List all LLM providers.

    Args:
        include_models: Whether to include models (with their costs) for each
            provider.
        enabled_only: If True, only include enabled models (for public routes).
            Has no effect when include_models is False. Providers themselves
            are not filtered by this flag.

    Returns:
        The provider records mapped through _map_provider.
    """
    include: dict[str, Any] | None = None
    if include_models:
        models_include: dict[str, Any] = {"include": {"Costs": True}}
        # Attach the filter only when one is requested, rather than passing an
        # explicit "where": None; this mirrors how list_models builds its filter.
        if enabled_only:
            models_include["where"] = {"isEnabled": True}
        include = {"Models": models_include}

    records = await prisma.models.LlmProvider.prisma().find_many(include=include)
    return [_map_provider(record) for record in records]
@@ -107,10 +121,24 @@ async def upsert_provider(
return _map_provider(record)
async def list_models(
    provider_id: str | None = None, enabled_only: bool = False
) -> list[llm_model.LlmModel]:
    """
    List LLM models.

    Args:
        provider_id: Optional filter by provider ID.
        enabled_only: If True, only return enabled models (for public routes).

    Returns:
        The matching model records (loaded with their costs) mapped through
        _map_model.
    """
    where: dict[str, Any] = {}
    if provider_id:
        where["providerId"] = provider_id
    if enabled_only:
        where["isEnabled"] = True

    records = await prisma.models.LlmModel.prisma().find_many(
        # An empty filter is passed as None, i.e. no filter at all.
        where=where or None,
        include={"Costs": True},
    )
    return [_map_model(record) for record in records]
@@ -227,12 +255,12 @@ async def delete_model(
"""
Delete a model and migrate all AgentNodes using it to a replacement model.
This performs an atomic operation:
This performs an atomic operation within a database transaction:
1. Validates the model exists
2. Validates the replacement model exists and is enabled
3. Counts affected nodes
4. Migrates all AgentNode.constantInput->model to replacement
5. Deletes the LlmModel record (CASCADE deletes costs)
4. Migrates all AgentNode.constantInput->model to replacement (in transaction)
5. Deletes the LlmModel record (CASCADE deletes costs) (in transaction)
Args:
model_id: UUID of the model to delete
@@ -244,9 +272,7 @@ async def delete_model(
Raises:
ValueError: If model not found, replacement not found, or replacement is disabled
"""
import prisma as prisma_module
# 1. Get the model being deleted
# 1. Get the model being deleted (validation - outside transaction)
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
@@ -256,7 +282,7 @@ async def delete_model(
deleted_slug = model.slug
deleted_display_name = model.displayName
# 2. Validate replacement model exists and is enabled
# 2. Validate replacement model exists and is enabled (validation - outside transaction)
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": replacement_model_slug}
)
@@ -268,7 +294,9 @@ async def delete_model(
f"Please enable it before using it as a replacement."
)
# 3. Count affected nodes
# 3. Count affected nodes (read - outside transaction)
import prisma as prisma_module
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
@@ -279,24 +307,26 @@ async def delete_model(
)
nodes_affected = int(count_result[0]["count"]) if count_result else 0
# 4. Perform migration
if nodes_affected > 0:
await prisma_module.get_client().execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
# 4 & 5. Perform migration and deletion atomically within a transaction
async with transaction() as tx:
# Migrate all AgentNode.constantInput->model to replacement
if nodes_affected > 0:
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
# 5. Delete the model (CASCADE will delete costs automatically)
await prisma.models.LlmModel.prisma().delete(where={"id": model_id})
# Delete the model (CASCADE will delete costs automatically)
await tx.llmmodel.delete(where={"id": model_id})
return llm_model.DeleteLlmModelResponse(
deleted_model_slug=deleted_slug,

View File

@@ -13,11 +13,13 @@ router = fastapi.APIRouter(
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models():
    """List all enabled LLM models available to users."""
    # Single query, restricted to enabled models.
    models = await llm_db.list_models(enabled_only=True)
    return llm_model.LlmModelsResponse(models=models)
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """List all LLM providers with their enabled models."""
    # Single query; only the nested models are filtered by enabled status.
    providers = await llm_db.list_providers(include_models=True, enabled_only=True)
    return llm_model.LlmProvidersResponse(providers=providers)