mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
format
This commit is contained in:
@@ -14,6 +14,16 @@ from backend.data.llm_registry.model_types import ModelMetadata
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
@@ -137,7 +147,7 @@ async def refresh_llm_registry() -> None:
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=cost.metadata or {},
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
@@ -159,8 +169,8 @@ async def refresh_llm_registry() -> None:
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=record.capabilities or {},
|
||||
extra_metadata=record.metadata or {},
|
||||
capabilities=_json_to_dict(record.capabilities),
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
provider_display_name=(
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Sequence
|
||||
from typing import Any, Iterable, Sequence, cast
|
||||
|
||||
import prisma
|
||||
import prisma.models
|
||||
@@ -102,6 +102,7 @@ async def list_providers(
|
||||
include_models: Whether to include models for each provider
|
||||
enabled_only: If True, only include enabled models (for public routes)
|
||||
"""
|
||||
include: Any = None
|
||||
if include_models:
|
||||
model_where = {"isEnabled": True} if enabled_only else None
|
||||
include = {
|
||||
@@ -110,8 +111,6 @@ async def list_providers(
|
||||
"where": model_where,
|
||||
}
|
||||
}
|
||||
else:
|
||||
include = None
|
||||
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
|
||||
return [_map_provider(record) for record in records]
|
||||
|
||||
@@ -120,7 +119,7 @@ async def upsert_provider(
|
||||
request: llm_model.UpsertLlmProviderRequest,
|
||||
provider_id: str | None = None,
|
||||
) -> llm_model.LlmProvider:
|
||||
data = {
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
@@ -133,7 +132,7 @@ async def upsert_provider(
|
||||
"supportsParallelTool": request.supports_parallel_tool,
|
||||
"metadata": request.metadata,
|
||||
}
|
||||
include = {"Models": {"include": {"Costs": True, "Creator": True}}}
|
||||
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
|
||||
if provider_id:
|
||||
record = await prisma.models.LlmProvider.prisma().update(
|
||||
where={"id": provider_id},
|
||||
@@ -145,6 +144,8 @@ async def upsert_provider(
|
||||
data=data,
|
||||
include=include,
|
||||
)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update provider")
|
||||
return _map_provider(record)
|
||||
|
||||
|
||||
@@ -158,7 +159,7 @@ async def list_models(
|
||||
provider_id: Optional filter by provider ID
|
||||
enabled_only: If True, only return enabled models (for public routes)
|
||||
"""
|
||||
where: dict[str, Any] = {}
|
||||
where: Any = {}
|
||||
if provider_id:
|
||||
where["providerId"] = provider_id
|
||||
if enabled_only:
|
||||
@@ -199,7 +200,7 @@ def _cost_create_payload(
|
||||
async def create_model(
|
||||
request: llm_model.CreateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
data: dict[str, Any] = {
|
||||
data: Any = {
|
||||
"slug": request.slug,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
@@ -226,7 +227,7 @@ async def update_model(
|
||||
request: llm_model.UpdateLlmModelRequest,
|
||||
) -> llm_model.LlmModel:
|
||||
# Build scalar field updates (non-relation fields)
|
||||
scalar_data: dict[str, Any] = {}
|
||||
scalar_data: Any = {}
|
||||
if request.display_name is not None:
|
||||
scalar_data["displayName"] = request.display_name
|
||||
if request.description is not None:
|
||||
@@ -265,7 +266,9 @@ async def update_model(
|
||||
cost_payload = _cost_create_payload(request.costs)
|
||||
for cost_item in cost_payload["create"]:
|
||||
cost_item["llmModelId"] = model_id
|
||||
await prisma.models.LlmModelCost.prisma().create(data=cost_item)
|
||||
await prisma.models.LlmModelCost.prisma().create(
|
||||
data=cast(Any, cost_item)
|
||||
)
|
||||
# Fetch the updated record
|
||||
record = await prisma.models.LlmModel.prisma().find_unique(
|
||||
where={"id": model_id},
|
||||
@@ -374,15 +377,16 @@ async def toggle_model(
|
||||
|
||||
# Create migration record for revert capability
|
||||
if nodes_migrated > 0:
|
||||
migration_data: Any = {
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
migration_record = await tx.llmmodelmigration.create(
|
||||
data={
|
||||
"sourceModelSlug": model.slug,
|
||||
"targetModelSlug": migrate_to_slug,
|
||||
"reason": migration_reason,
|
||||
"migratedNodeIds": json.dumps(migrated_node_ids),
|
||||
"nodeCount": nodes_migrated,
|
||||
"customCreditCost": custom_credit_cost,
|
||||
}
|
||||
data=migration_data
|
||||
)
|
||||
migration_id = migration_record.id
|
||||
else:
|
||||
@@ -393,6 +397,8 @@ async def toggle_model(
|
||||
include={"Costs": True},
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Model with id '{model_id}' not found")
|
||||
return llm_model.ToggleLlmModelResponse(
|
||||
model=_map_model(record),
|
||||
nodes_migrated=nodes_migrated,
|
||||
@@ -540,7 +546,7 @@ async def list_migrations(
|
||||
Returns:
|
||||
List of LlmModelMigration records
|
||||
"""
|
||||
where = None if include_reverted else {"isReverted": False}
|
||||
where: Any = None if include_reverted else {"isReverted": False}
|
||||
records = await prisma.models.LlmModelMigration.prisma().find_many(
|
||||
where=where,
|
||||
order={"createdAt": "desc"},
|
||||
@@ -711,7 +717,7 @@ async def upsert_creator(
|
||||
creator_id: str | None = None,
|
||||
) -> llm_model.LlmModelCreator:
|
||||
"""Create or update a model creator."""
|
||||
data = {
|
||||
data: Any = {
|
||||
"name": request.name,
|
||||
"displayName": request.display_name,
|
||||
"description": request.description,
|
||||
@@ -726,6 +732,8 @@ async def upsert_creator(
|
||||
)
|
||||
else:
|
||||
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
|
||||
if record is None:
|
||||
raise ValueError("Failed to create/update creator")
|
||||
return _map_creator(record)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user