Make LLM cost refresh async and support overrides

Convert refresh_llm_costs to async and update all callers to await it. Implement async _build_llm_costs_from_registry which queries prisma LlmModelMigration for active migrations with customCreditCost and applies per-model pricing overrides when present (with a safe try/except). Add two SQL migrations: a composite index on LlmModelMigration to optimize override queries and a sync migration to add/remove/update LLM models and their costs. This ensures billing uses migration-provided custom pricing and that registry refreshes correctly await cost recalculation.
This commit is contained in:
Bentlybro
2026-02-12 11:11:01 +00:00
parent 8e6bc5eb48
commit b11d46d246
7 changed files with 128 additions and 10 deletions

View File

@@ -22,7 +22,7 @@ async def _refresh_runtime_state() -> None:
try:
# Refresh registry from database
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.data.block import BlockSchema

View File

@@ -120,7 +120,7 @@ async def lifespan_context(app: fastapi.FastAPI):
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.data.block import BlockSchema

View File

@@ -911,7 +911,7 @@ async def initialize_blocks() -> None:
if is_connected():
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
await refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(

View File

@@ -1,6 +1,8 @@
import logging
from typing import Type
import prisma.models
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
from backend.blocks.ai_image_generator_block import AIImageGeneratorBlock, ImageGenModel
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
@@ -75,8 +77,40 @@ PROVIDER_CREDENTIALS = {
LLM_COST: list[BlockCost] = []
def _build_llm_costs_from_registry() -> list[BlockCost]:
"""Build BlockCost list from all models in the LLM registry."""
async def _build_llm_costs_from_registry() -> list[BlockCost]:
"""
Build BlockCost list from all models in the LLM registry.
This function checks for active model migrations with customCreditCost overrides.
When a model has been migrated with a custom price, that price is used instead
of the target model's default cost.
"""
# Query active migrations with custom pricing overrides
migration_overrides: dict[str, int] = {}
try:
active_migrations = await prisma.models.LlmModelMigration.prisma().find_many(
where={
"isReverted": False,
"customCreditCost": {"not": None},
}
)
migration_overrides = {
migration.sourceModelSlug: migration.customCreditCost
for migration in active_migrations
if migration.customCreditCost is not None
}
if migration_overrides:
logger.info(
"Found %d active model migrations with custom pricing overrides",
len(migration_overrides),
)
except Exception as exc:
logger.warning(
"Failed to query model migration overrides: %s. Proceeding with default costs.",
exc,
exc_info=True,
)
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
@@ -88,6 +122,18 @@ def _build_llm_costs_from_registry() -> list[BlockCost]:
cost.credential_provider,
)
continue
# Check if this model has a custom cost override from migration
cost_amount = migration_overrides.get(model.slug, cost.credit_cost)
if model.slug in migration_overrides:
logger.debug(
"Applying custom cost override for model %s: %d credits (default: %d)",
model.slug,
cost_amount,
cost.credit_cost,
)
cost_filter = {
"model": model.slug,
"credentials": {
@@ -100,16 +146,21 @@ def _build_llm_costs_from_registry() -> list[BlockCost]:
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost.credit_cost,
cost_amount=cost_amount,
)
)
return costs
def refresh_llm_costs() -> None:
"""Refresh LLM costs from the registry. All costs now come from the database."""
async def refresh_llm_costs() -> None:
"""
Refresh LLM costs from the registry. All costs now come from the database.
This function also checks for active model migrations with custom pricing overrides
and applies them to ensure accurate billing.
"""
LLM_COST.clear()
LLM_COST.extend(_build_llm_costs_from_registry())
LLM_COST.extend(await _build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup

View File

@@ -47,7 +47,7 @@ async def refresh_registry_on_notification() -> None:
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
await refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()

View File

@@ -0,0 +1,6 @@
-- Add composite index on LlmModelMigration for optimized active migration queries
-- This index improves performance when querying for non-reverted migrations by model slug
-- Used by the billing system to apply customCreditCost overrides
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");

View File

@@ -0,0 +1,61 @@
-- Sync LLM models with latest dev branch changes
-- This migration adds new models and removes deprecated ones
-- Remove models that were deleted from dev
DELETE FROM "LlmModelCost" WHERE "llmModelId" IN (
SELECT "id" FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219')
);
DELETE FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219');
-- Add new models from dev
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- New OpenAI model
('gpt-5.2-2025-12-11', 'GPT 5.2', 'openai', 400000, 128000),
-- New Anthropic model
('claude-opus-4-6', 'Claude 4.6 Opus', 'anthropic', 200000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Add costs for new models
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- New model costs (estimate based on similar models)
('gpt-5.2-2025-12-11', 5), -- Similar to GPT 5.1
('claude-opus-4-6', 21) -- Similar to other Opus 4.x models
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;