This commit is contained in:
Bentlybro
2026-01-06 13:40:16 +00:00
parent 90ae75d475
commit 63869fe710
2 changed files with 40 additions and 22 deletions

View File

@@ -14,6 +14,16 @@ from backend.data.llm_registry.model_types import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
@@ -137,7 +147,7 @@ async def refresh_llm_registry() -> None:
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=cost.metadata or {},
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
@@ -159,8 +169,8 @@ async def refresh_llm_registry() -> None:
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=record.capabilities or {},
extra_metadata=record.metadata or {},
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Iterable, Sequence
from typing import Any, Iterable, Sequence, cast
import prisma
import prisma.models
@@ -102,6 +102,7 @@ async def list_providers(
include_models: Whether to include models for each provider
enabled_only: If True, only include enabled models (for public routes)
"""
include: Any = None
if include_models:
model_where = {"isEnabled": True} if enabled_only else None
include = {
@@ -110,8 +111,6 @@ async def list_providers(
"where": model_where,
}
}
else:
include = None
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
@@ -120,7 +119,7 @@ async def upsert_provider(
request: llm_model.UpsertLlmProviderRequest,
provider_id: str | None = None,
) -> llm_model.LlmProvider:
data = {
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
@@ -133,7 +132,7 @@ async def upsert_provider(
"supportsParallelTool": request.supports_parallel_tool,
"metadata": request.metadata,
}
include = {"Models": {"include": {"Costs": True, "Creator": True}}}
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
if provider_id:
record = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
@@ -145,6 +144,8 @@ async def upsert_provider(
data=data,
include=include,
)
if record is None:
raise ValueError("Failed to create/update provider")
return _map_provider(record)
@@ -158,7 +159,7 @@ async def list_models(
provider_id: Optional filter by provider ID
enabled_only: If True, only return enabled models (for public routes)
"""
where: dict[str, Any] = {}
where: Any = {}
if provider_id:
where["providerId"] = provider_id
if enabled_only:
@@ -199,7 +200,7 @@ def _cost_create_payload(
async def create_model(
request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
data: dict[str, Any] = {
data: Any = {
"slug": request.slug,
"displayName": request.display_name,
"description": request.description,
@@ -226,7 +227,7 @@ async def update_model(
request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
# Build scalar field updates (non-relation fields)
scalar_data: dict[str, Any] = {}
scalar_data: Any = {}
if request.display_name is not None:
scalar_data["displayName"] = request.display_name
if request.description is not None:
@@ -265,7 +266,9 @@ async def update_model(
cost_payload = _cost_create_payload(request.costs)
for cost_item in cost_payload["create"]:
cost_item["llmModelId"] = model_id
await prisma.models.LlmModelCost.prisma().create(data=cost_item)
await prisma.models.LlmModelCost.prisma().create(
data=cast(Any, cost_item)
)
# Fetch the updated record
record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
@@ -374,15 +377,16 @@ async def toggle_model(
# Create migration record for revert capability
if nodes_migrated > 0:
migration_data: Any = {
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
migration_record = await tx.llmmodelmigration.create(
data={
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
data=migration_data
)
migration_id = migration_record.id
else:
@@ -393,6 +397,8 @@ async def toggle_model(
include={"Costs": True},
)
if record is None:
raise ValueError(f"Model with id '{model_id}' not found")
return llm_model.ToggleLlmModelResponse(
model=_map_model(record),
nodes_migrated=nodes_migrated,
@@ -540,7 +546,7 @@ async def list_migrations(
Returns:
List of LlmModelMigration records
"""
where = None if include_reverted else {"isReverted": False}
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
@@ -711,7 +717,7 @@ async def upsert_creator(
creator_id: str | None = None,
) -> llm_model.LlmModelCreator:
"""Create or update a model creator."""
data = {
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
@@ -726,6 +732,8 @@ async def upsert_creator(
)
else:
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
if record is None:
raise ValueError("Failed to create/update creator")
return _map_creator(record)