fix: address CodeRabbit review comments for LLM registry PR

- Add try/except around startup LLM registry refresh (rest_api.py)
- Fix race condition in refresh_llm_costs - build the new cost list first, then clear and refill LLM_COST with no await in between (block_cost_config.py)
- Add @lru_cache to _get_llm_models() with cache clear on refresh (db.py, llm_routes.py)
- Fix retry using wrong model's max_output_tokens after fallback (llm.py)
- Remove redundant @@index([slug]) from schema.prisma (slug is already @unique, which creates an implicit index)
- Fix operationId collision in openapi.json
This commit is contained in:
Bentlybro
2026-02-27 14:41:23 +00:00
parent 13de0af0b3
commit 5752c413f7
7 changed files with 46 additions and 10 deletions

View File

@@ -47,6 +47,8 @@ async def _refresh_runtime_state() -> None:
logger.info("Cleared v2 builder providers cache")
builder_db._build_cached_search_results.cache_clear()
logger.info("Cleared v2 builder search results cache")
builder_db._get_llm_models.cache_clear()
logger.info("Cleared v2 builder LLM models cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)

View File

@@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass
from difflib import SequenceMatcher
from functools import lru_cache
from typing import Any, Sequence, get_args, get_origin
import prisma
@@ -40,11 +41,17 @@ from .model import (
logger = logging.getLogger(__name__)
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
@lru_cache(maxsize=1)
def _get_llm_models() -> tuple[str, ...]:
"""Get LLM model names for search matching from the registry.
Cached to avoid rebuilding on every search call.
Cache is cleared when registry is refreshed via _refresh_runtime_state.
Returns an immutable tuple so callers cannot mutate the shared cached value
(lru_cache only requires hashable arguments, not a hashable return value).
"""
return tuple(
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
)
MAX_LIBRARY_AGENT_RESULTS = 100

View File

@@ -120,8 +120,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
try:
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
except Exception as e:
logger.warning(f"Failed to refresh LLM registry/costs at startup: {e}")
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.blocks._base import BlockSchema

View File

@@ -378,6 +378,9 @@ class LLMResponse(BaseModel):
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
resolved_max_output_tokens: Optional[int] = (
None # Max output tokens of resolved model (after fallback)
)
def convert_openai_tool_fmt_to_anthropic(
@@ -553,6 +556,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
resolved_max_output_tokens=model_max_output,
)
elif provider == "anthropic":
@@ -633,6 +637,7 @@ async def llm_call(
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
reasoning=reasoning,
resolved_max_output_tokens=model_max_output,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
@@ -658,6 +663,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=None,
resolved_max_output_tokens=model_max_output,
)
elif provider == "ollama":
if tools:
@@ -680,6 +686,7 @@ async def llm_call(
prompt_tokens=response.get("prompt_eval_count") or 0,
completion_tokens=response.get("eval_count") or 0,
reasoning=None,
resolved_max_output_tokens=model_max_output,
)
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -722,6 +729,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
resolved_max_output_tokens=model_max_output,
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -764,6 +772,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
resolved_max_output_tokens=model_max_output,
)
elif provider == "aiml_api":
client = openai.AsyncOpenAI(
@@ -792,6 +801,7 @@ async def llm_call(
completion.usage.completion_tokens if completion.usage else 0
),
reasoning=None,
resolved_max_output_tokens=model_max_output,
)
elif provider == "v0":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -828,6 +838,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
resolved_max_output_tokens=model_max_output,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
@@ -953,6 +964,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt_tokens=0,
completion_tokens=0,
reasoning=None,
resolved_max_output_tokens=None,
),
"get_collision_proof_output_tag_id": lambda *args: "test123456",
},
@@ -1028,6 +1040,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
# Resolve the model once before retry loop to get correct max_output_tokens
# for retry logic (handles fallback models correctly)
resolved = await resolve_model_for_call(llm_model)
resolved_max_output = resolved.max_output_tokens
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
try:
@@ -1150,7 +1167,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
or "token limit" in str(e).lower()
):
if input_data.max_tokens is None:
input_data.max_tokens = llm_model.max_output_tokens or 4096
# Use resolved model's max_output_tokens (handles fallback correctly)
input_data.max_tokens = resolved_max_output or 4096
input_data.max_tokens = int(input_data.max_tokens * 0.85)
logger.debug(
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"

View File

@@ -163,8 +163,11 @@ async def refresh_llm_costs() -> None:
This function also checks for active model migrations with custom pricing overrides
and applies them to ensure accurate billing.
"""
# Build new costs first, then swap atomically to avoid race condition
# where concurrent readers see an empty list during the await
new_costs = await _build_llm_costs_from_registry()
LLM_COST.clear()
LLM_COST.extend(await _build_llm_costs_from_registry())
LLM_COST.extend(new_costs)
# Initial load will happen after registry is refreshed at startup

View File

@@ -1265,7 +1265,7 @@ model LlmModel {
@@index([providerId, isEnabled])
@@index([creatorId])
@@index([slug])
// Note: slug already has @unique which creates an implicit index
}
model LlmModelCost {
@@ -1330,6 +1330,9 @@ model LlmModelMigration {
@@index([isReverted])
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////// OAUTH PROVIDER TABLES //////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

View File

@@ -4636,7 +4636,7 @@
"tags": ["v2", "admin", "llm", "llm", "admin"],
"summary": "Get creator details",
"description": "Get details of a specific model creator.",
"operationId": "getV2Get creator details",
"operationId": "getV2GetLlmCreatorDetails",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{