mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Refactor LLM model registry to use database
Migrates LLM model metadata and cost configuration from static code to a dynamic database-driven registry. Adds new backend modules for LLM registry and model types, updates block and cost configuration logic to fetch model info and costs from the database, and ensures block schemas and UI options reflect enabled/disabled models. This enables dynamic management of LLM models and costs via the admin UI and database migrations.
This commit is contained in:
@@ -9,6 +9,7 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
llm_model_schema_extra,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
@@ -52,6 +53,7 @@ class AIConditionBlock(AIBlockBase):
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
from typing import Any, Iterable, List, Literal, Optional
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
@@ -22,6 +22,8 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.llm_model_types import ModelMetadata
|
||||
from backend.data import llm_registry
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -66,19 +68,24 @@ TEST_CREDENTIALS_INPUT = {
|
||||
|
||||
|
||||
def AICredentialsField() -> AICredentials:
    """
    Return a CredentialsField for LLM providers.

    The field is discriminated on the block's ``model`` input. The
    model -> provider mapping is read from the LLM registry at call time.

    Returns:
        The credentials field definition produced by ``CredentialsField``.
    """
    # The mapping may be empty if the registry has not been loaded yet; per the
    # original note it is refreshed when the schema is generated via
    # CredentialsMetaInput._add_json_schema_extra.
    mapping = llm_registry.get_llm_discriminator_mapping()

    # Pass the mapping exactly once: supplying both the old enum-derived dict
    # and the registry mapping would be a duplicate keyword argument.
    return CredentialsField(
        description="API key for the LLM provider.",
        discriminator="model",
        discriminator_mapping=mapping,  # May be empty initially, refreshed later
    )
|
||||
|
||||
|
||||
# NOTE(review): the same name is also imported from backend.data.llm_model_types
# near the top of this file; this local definition looks superseded - confirm
# which one is meant to be authoritative.
class ModelMetadata(NamedTuple):
    # Provider identifier string (e.g. "openai", "anthropic").
    provider: str
    # Context window size; presumably measured in tokens - TODO confirm.
    context_window: int
    # Maximum output size, or None when no limit is recorded.
    max_output_tokens: int | None
|
||||
def llm_model_schema_extra() -> dict[str, Any]:
    """Build the ``json_schema_extra`` payload for LLM model fields.

    Returns:
        A dict with a single ``"options"`` key holding whatever
        ``llm_registry.get_llm_model_schema_options()`` returns at call time.
    """
    options = llm_registry.get_llm_model_schema_options()
    return {"options": options}
|
||||
|
||||
|
||||
class LlmModelMeta(EnumMeta):
|
||||
@@ -170,9 +177,21 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
V0_1_5_LG = "v0-1.5-lg"
|
||||
V0_1_0_MD = "v0-1.0-md"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if isinstance(value, str):
|
||||
pseudo_member = str.__new__(cls, value)
|
||||
pseudo_member._name_ = value.upper().replace("-", "_")
|
||||
pseudo_member._value_ = value
|
||||
return pseudo_member
|
||||
return super()._missing_(value)
|
||||
|
||||
@property
def metadata(self) -> ModelMetadata:
    """Return this model's metadata from the LLM registry.

    Raises:
        ValueError: if the registry returns nothing (or a falsy value) for
            this model's value.
    """
    # The static MODEL_METADATA lookup is gone; an early return on it would
    # make the registry lookup below unreachable.
    metadata = llm_registry.get_llm_model_metadata(self.value)
    if metadata:
        return metadata
    raise ValueError(
        f"Missing metadata for model: {self.value}. Model not found in LLM registry."
    )
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
@@ -187,125 +206,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
return self.metadata.max_output_tokens
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
# https://platform.openai.com/docs/models
|
||||
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
|
||||
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
|
||||
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
"openai", 128000, 16384
|
||||
), # gpt-4o-mini-2024-07-18
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, 16384), # gpt-4o-2024-08-06
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-opus-4-5-20251101
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096
|
||||
), # claude-3-haiku-20240307
|
||||
# https://docs.aimlapi.com/api-overview/model-database/text-models
|
||||
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata("aiml_api", 32000, 8000),
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata("aiml_api", 128000, 40000),
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata("aiml_api", 128000, None),
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata("aiml_api", 131000, 2000),
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata("aiml_api", 128000, None),
|
||||
# https://console.groq.com/docs/models
|
||||
LlmModel.LLAMA3_3_70B: ModelMetadata("groq", 128000, 32768),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
|
||||
# https://ollama.com/library
|
||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, None),
|
||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata("open_router", 1048576, 65535),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata("open_router", 1048576, 65535),
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata("open_router", 1048576, 8192),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535
|
||||
),
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata("open_router", 1048576, 8192),
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata("open_router", 128000, 4096),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata("open_router", 64000, 2048),
|
||||
LlmModel.DEEPSEEK_R1_0528: ModelMetadata("open_router", 163840, 163840),
|
||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 8000),
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
16000,
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||
"open_router", 131000, 4096
|
||||
),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||
"open_router", 12288, 12288
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
|
||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
|
||||
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
|
||||
LlmModel.GROK_4: ModelMetadata("open_router", 256000, 256000),
|
||||
LlmModel.GROK_4_FAST: ModelMetadata("open_router", 2000000, 30000),
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata("open_router", 2000000, 30000),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata("open_router", 256000, 10000),
|
||||
LlmModel.KIMI_K2: ModelMetadata("open_router", 131000, 131000),
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata("open_router", 262144, 262144),
|
||||
LlmModel.QWEN3_CODER: ModelMetadata("open_router", 262144, 262144),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata("llama_api", 128000, 4028),
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
|
||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
# MODEL_METADATA removed - all models now come from the database via llm_registry
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -434,8 +335,81 @@ async def llm_call(
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
# Get model metadata - try cache first, then fallback to async lookup
|
||||
# Also check if the model is enabled
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
|
||||
# Check if model is enabled - get from registry
|
||||
from backend.data.llm_registry import _dynamic_models
|
||||
if llm_model.value in _dynamic_models:
|
||||
model_info = _dynamic_models[llm_model.value]
|
||||
if not model_info.is_enabled:
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled."
|
||||
)
|
||||
except ValueError as e:
|
||||
# Re-raise if it's our disabled model error
|
||||
if "is disabled" in str(e):
|
||||
raise
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Model %s not found in registry cache",
|
||||
llm_model.value,
|
||||
)
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info("Refreshing LLM registry and retrying lookup for %s", llm_model.value)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
|
||||
# Check if model is enabled after refresh
|
||||
from backend.data.llm_registry import _dynamic_models
|
||||
if llm_model.value in _dynamic_models:
|
||||
model_info = _dynamic_models[llm_model.value]
|
||||
if not model_info.is_enabled:
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled. "
|
||||
"Please enable it in the LLM registry via the admin UI to use this model."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Successfully loaded model %s metadata after registry refresh",
|
||||
llm_model.value,
|
||||
)
|
||||
except ValueError as ve:
|
||||
# Re-raise if it's our disabled model error
|
||||
if "is disabled" in str(ve):
|
||||
raise
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error("Failed to refresh LLM registry: %s", refresh_exc, exc_info=True)
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
) from e
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
@@ -446,7 +420,7 @@ async def llm_call(
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
# model_max_output already set above
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
@@ -793,6 +767,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
@@ -1224,6 +1199,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
@@ -1319,6 +1295,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
title="Focus",
|
||||
@@ -1536,6 +1513,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_tokens: int | None = SchemaField(
|
||||
@@ -1638,6 +1616,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
json_schema_extra=llm_model_schema_extra(),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
max_retries: int = SchemaField(
|
||||
|
||||
@@ -10,12 +10,12 @@ import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
ModelMetadata,
|
||||
)
|
||||
from backend.data import llm_registry
|
||||
from backend.blocks.stagehand._config import stagehand as stagehand_provider
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
Returns the provider name for the model in the required format for Stagehand:
|
||||
provider/model_name
|
||||
"""
|
||||
model_metadata = MODEL_METADATA[LlmModel(self.value)]
|
||||
model_metadata = self.metadata
|
||||
model_name = self.value
|
||||
|
||||
if len(model_name.split("/")) == 1 and not self.value.startswith(
|
||||
@@ -107,19 +107,23 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
|
||||
@property
def provider(self) -> str:
    """Return the provider recorded in this model's metadata."""
    # Goes through the ``metadata`` property rather than the removed
    # MODEL_METADATA table.
    return self.metadata.provider
|
||||
|
||||
@property
def metadata(self) -> ModelMetadata:
    """Return this model's metadata.

    The LLM registry is consulted first; if it returns nothing (or a falsy
    value), the lookup is delegated to ``LlmModel(self.value).metadata``.
    """
    # The static MODEL_METADATA lookup is gone; an early return on it would
    # make the registry lookup below unreachable.
    metadata = llm_registry.get_llm_model_metadata(self.value)
    if metadata:
        return metadata
    # Fallback to LlmModel enum if registry lookup fails
    return LlmModel(self.value).metadata
|
||||
|
||||
@property
def context_window(self) -> int:
    """Return the context window recorded in this model's metadata."""
    # Goes through the ``metadata`` property rather than the removed
    # MODEL_METADATA table.
    return self.metadata.context_window
|
||||
|
||||
@property
|
||||
def max_output_tokens(self) -> int | None:
|
||||
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
|
||||
return self.metadata.max_output_tokens
|
||||
|
||||
|
||||
class StagehandObserveBlock(Block):
|
||||
|
||||
@@ -140,37 +140,234 @@ class BlockInfo(BaseModel):
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
|
||||
|
||||
@classmethod
def clear_schema_cache(cls) -> None:
    """Drop the JSON schema cached on this class.

    ``jsonschema()`` rebuilds the schema whenever the cached value is falsy,
    so the next call after this one regenerates it.
    """
    # Reset to None, the same value the ClassVar is declared with.
    cls.cached_jsonschema = None  # type: ignore
|
||||
|
||||
@staticmethod
def clear_all_schema_caches() -> None:
    """Clear the cached JSON schema on BlockSchema and on every class derived from it."""
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # classes are visited in the same pre-order a recursive walk would use.
    pending = [BlockSchema]
    while pending:
        klass = pending.pop()
        if hasattr(klass, 'clear_schema_cache'):
            klass.clear_schema_cache()
        pending.extend(reversed(klass.__subclasses__()))
|
||||
|
||||
@classmethod
def jsonschema(cls) -> dict[str, Any]:
    """Return the JSON schema for this block schema class, building it on first use.

    The schema is generated from ``model_json_schema()`` with references
    resolved by ``jsonref.replace_refs``, flattened by ``ref_to_dict`` and
    stored in ``cached_jsonschema``. On every call, cached or not, the schema
    is post-processed by ``_ensure_discriminator_in_schema``.

    Returns:
        The cached schema dict. The same object is returned on each call and
        is mutated in place by the post-processing step.
    """
    # Generate schema if not cached
    if not cls.cached_jsonschema:
        model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)

        def ref_to_dict(obj):
            """Recursively copy ``obj``, dropping "$"-prefixed keys and collapsing single-item combinators."""
            if isinstance(obj, dict):
                # OpenAPI <3.1 does not support sibling fields that has a $ref key
                # So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
                keys = {"allOf", "anyOf", "oneOf"}
                one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
                if one_key:
                    # Merge the single alternative into its parent; the
                    # combinator key itself is dropped below.
                    obj.update(obj[one_key][0])

                return {
                    key: ref_to_dict(value)
                    for key, value in obj.items()
                    if not key.startswith("$") and key != one_key
                }
            elif isinstance(obj, list):
                return [ref_to_dict(item) for item in obj]

            return obj

        cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))

    # Always post-process to ensure discriminator is present for multi-provider credentials fields
    # and refresh LLM model options to get latest enabled/disabled status.
    # This handles cases where the schema was generated before the registry was loaded or updated.
    # The cached schema is mutated directly, so this runs even on a cache hit.
    cls._ensure_discriminator_in_schema(cls.cached_jsonschema, cls)

    return cls.cached_jsonschema
|
||||
|
||||
@staticmethod
def _ensure_discriminator_in_schema(schema: dict[str, Any], model_class: type | None = None) -> None:
    """Ensure discriminator is present in multi-provider credentials fields and refresh LLM model options.

    Mutates ``schema`` in place. For every dict-valued entry of
    ``schema["properties"]``:

    * if the field is recognised as an LLM model field, its "options" / "enum"
      lists are refreshed from the LLM registry;
    * if the field is a credentials field, its "credentials_provider",
      "discriminator" and "discriminator_mapping" entries are filled in
      where missing.

    Args:
        schema: The generated JSON schema to post-process.
        model_class: The class the schema was generated from; its
            ``model_fields`` are consulted when present.
    """
    properties = schema.get("properties", {})
    for field_name, field_schema in properties.items():
        if not isinstance(field_schema, dict):
            continue

        # Only refresh LLM model options for LLM model fields.
        # This prevents filtering other enum fields that aren't LLM models.
        if BlockSchema._is_llm_model_field(model_class, field_name):
            BlockSchema._refresh_llm_model_options(field_name, field_schema)

        BlockSchema._ensure_credentials_discriminator(field_name, field_schema, model_class)


@staticmethod
def _is_llm_model_field(model_class: type | None, field_name: str) -> bool:
    """Return True if the field has "options" in its json_schema_extra or is annotated with LlmModel."""
    import logging
    from typing import get_args, get_origin

    if not (model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields):
        return False

    try:
        field_info = model_class.model_fields[field_name]

        # "options" in json_schema_extra is what llm_model_schema_extra() sets.
        extra = getattr(field_info, "json_schema_extra", None)
        if isinstance(extra, dict) and "options" in extra:
            return True

        if not hasattr(field_info, "annotation"):
            return False

        # Imported lazily, as in the original; presumably to avoid a circular
        # import with backend.blocks.llm - TODO confirm.
        from backend.blocks.llm import LlmModel

        annotation = field_info.annotation
        if annotation == LlmModel:
            return True

        # Optional[LlmModel] or other Union types that include LlmModel.
        return bool(get_origin(annotation)) and LlmModel in get_args(annotation)
    except Exception:
        # Best-effort detection: never let schema generation fail here, but
        # leave a trace instead of swallowing the error silently.
        logging.getLogger(__name__).debug(
            "Could not determine whether field %s is an LLM model field", field_name, exc_info=True
        )
        return False


@staticmethod
def _refresh_llm_model_options(field_name: str, schema_part: dict[str, Any], path: str = "") -> bool:
    """Recursively refresh options in schema part. Returns True if options were found and refreshed."""
    import logging

    logger = logging.getLogger(__name__)
    location = f".{path}" if path else ""

    # "options" is used by the frontend for select dropdowns.
    has_options = isinstance(schema_part.get("options"), list)
    # "enum" is what Pydantic generates for Enum fields.
    has_enum = isinstance(schema_part.get("enum"), list)

    if has_options or has_enum:
        try:
            from backend.data import llm_registry

            # Always refresh options from registry to get latest enabled/disabled status
            fresh_options = llm_registry.get_llm_model_schema_options()
            if fresh_options:
                enabled_slugs = {
                    opt.get("value") for opt in fresh_options if isinstance(opt, dict) and "value" in opt
                }

                if has_options:
                    old_options = schema_part["options"]
                    old_slugs = {
                        opt.get("value") for opt in old_options if isinstance(opt, dict) and "value" in opt
                    }
                    schema_part["options"] = fresh_options

                    # Log only when the set of models actually changed.
                    removed = old_slugs - enabled_slugs
                    added = enabled_slugs - old_slugs
                    if removed or added:
                        logger.info(
                            "Refreshed LLM model options for field %s%s: %d -> %d models. "
                            "Removed: %s, Added: %s",
                            field_name, location, len(old_options), len(fresh_options), removed, added
                        )

                if has_enum:
                    # Filter enum values to only include enabled models.
                    old_enum = schema_part.get("enum", [])
                    filtered_enum = [val for val in old_enum if val in enabled_slugs]
                    schema_part["enum"] = filtered_enum

                    if len(old_enum) != len(filtered_enum):
                        removed_enum = set(old_enum) - enabled_slugs
                        logger.info(
                            "Filtered LLM model enum for field %s%s: %d -> %d models. "
                            "Removed disabled: %s",
                            field_name, location, len(old_enum), len(filtered_enum), removed_enum
                        )

                return True
        except Exception as e:
            logger.warning("Failed to refresh LLM model options for field %s%s: %s", field_name, location, e)

    # Check nested structures; stop at the first part that was refreshed.
    for key in ["anyOf", "oneOf", "allOf"]:
        if key in schema_part and isinstance(schema_part[key], list):
            for idx, item in enumerate(schema_part[key]):
                nested_path = f"{path}.{key}[{idx}]" if path else f"{key}[{idx}]"
                if isinstance(item, dict) and BlockSchema._refresh_llm_model_options(field_name, item, nested_path):
                    return True

    return False


@staticmethod
def _ensure_credentials_discriminator(
    field_name: str, field_schema: dict[str, Any], model_class: type | None = None
) -> None:
    """Fill in providers, discriminator and discriminator_mapping on a credentials field schema."""
    import inspect
    import logging
    # Imported unconditionally at function scope: get_origin is needed in the
    # annotation check below regardless of which other branches ran.
    from typing import get_origin

    logger = logging.getLogger(__name__)

    # A credentials field is recognised by credentials_provider or credentials_types.
    if "credentials_provider" not in field_schema and "credentials_types" not in field_schema:
        return

    providers = field_schema.get("credentials_provider", [])
    has_model_field = bool(
        model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields
    )

    # If providers are not in the field schema yet, try to get them from the model class.
    if not providers and has_model_field:
        try:
            field_info = model_class.model_fields[field_name]
            from backend.data.model import CredentialsMetaInput

            annotation = getattr(field_info, "annotation", None)
            if (
                hasattr(field_info, "annotation")
                and inspect.isclass(annotation)
                and issubclass(get_origin(annotation) or annotation, CredentialsMetaInput)
            ):
                # Get providers from the annotation.
                providers_list = CredentialsMetaInput.allowed_providers.__func__(annotation)
                if providers_list:
                    providers = list(providers_list)
                    field_schema["credentials_provider"] = providers
        except Exception:
            logger.debug("Could not derive credentials providers for field %s", field_name, exc_info=True)

    # Multi-provider field - ensure discriminator is set.
    if isinstance(providers, list) and len(providers) > 1 and "discriminator" not in field_schema:
        discriminator_found = False
        if has_model_field:
            try:
                field_info = model_class.model_fields[field_name]
                extra = getattr(field_info, "json_schema_extra", None)
                if isinstance(extra, dict):
                    discriminator = extra.get("discriminator")
                    if discriminator:
                        field_schema["discriminator"] = discriminator
                        discriminator_found = True
            except Exception:
                logger.debug("Could not read discriminator for field %s", field_name, exc_info=True)

        # If not found, check if this looks like an LLM field and default to "model".
        if not discriminator_found:
            llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", "aiml_api", "v0", "ollama"}
            if any(p in llm_providers for p in providers):
                field_schema["discriminator"] = "model"

    # If discriminator is "model", ensure discriminator_mapping is populated.
    # NOTE(review): whether this step was meant to run only for multi-provider
    # fields could not be confirmed; it is applied to every credentials field
    # whose discriminator is "model" - confirm against the intended nesting.
    if field_schema.get("discriminator") == "model" and not field_schema.get("discriminator_mapping"):
        try:
            from backend.data import llm_registry

            # Ensure at least an empty dict is present when the registry has nothing.
            field_schema["discriminator_mapping"] = llm_registry.get_llm_discriminator_mapping() or {}
        except Exception:
            logger.debug("Could not refresh discriminator mapping for field %s", field_name, exc_info=True)
            field_schema.setdefault("discriminator_mapping", {})
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
@@ -733,6 +930,23 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# This ensures the registry cache is populated even in executor context
|
||||
try:
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
|
||||
# Only refresh if we have DB access (check if Prisma is connected)
|
||||
from backend.data.db import is_connected
|
||||
if is_connected():
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
logger.info("LLM registry refreshed during block initialization")
|
||||
else:
|
||||
logger.warning("Prisma not connected, skipping LLM registry refresh during block initialization")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to refresh LLM registry during block initialization: %s", exc)
|
||||
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
|
||||
@@ -23,19 +24,18 @@ from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
AIListGeneratorBlock,
|
||||
AIStructuredResponseGeneratorBlock,
|
||||
AITextGeneratorBlock,
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
@@ -55,209 +55,63 @@ from backend.integrations.credentials_store import (
|
||||
v0_credentials,
|
||||
)
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT41_MINI: 1,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
LlmModel.QWEN3_CODER: 9,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: 1,
|
||||
LlmModel.V0_1_5_LG: 2,
|
||||
LlmModel.V0_1_0_MD: 1,
|
||||
PROVIDER_CREDENTIALS = {
|
||||
"openai": openai_credentials,
|
||||
"anthropic": anthropic_credentials,
|
||||
"groq": groq_credentials,
|
||||
"open_router": open_router_credentials,
|
||||
"llama_api": llama_api_credentials,
|
||||
"aiml_api": aiml_api_credentials,
|
||||
"v0": v0_credentials,
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_COST:
|
||||
raise ValueError(f"Missing MODEL_COST for model: {model}")
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
# All LLM costs now come from the database via llm_registry
|
||||
|
||||
LLM_COST: list[BlockCost] = []
|
||||
|
||||
|
||||
LLM_COST = (
|
||||
# Anthropic Models
|
||||
[
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
def _build_llm_costs_from_registry() -> list[BlockCost]:
|
||||
"""Build BlockCost list from all models in the LLM registry."""
|
||||
costs: list[BlockCost] = []
|
||||
for model in llm_registry.iter_dynamic_models():
|
||||
for cost in model.costs:
|
||||
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
|
||||
if not credentials:
|
||||
logger.warning(
|
||||
"Skipping cost entry for %s due to unknown credentials provider %s",
|
||||
model.slug,
|
||||
cost.credential_provider,
|
||||
)
|
||||
continue
|
||||
cost_filter = {
|
||||
"model": model.slug,
|
||||
"credentials": {
|
||||
"id": anthropic_credentials.id,
|
||||
"provider": anthropic_credentials.provider,
|
||||
"type": anthropic_credentials.type,
|
||||
"id": credentials.id,
|
||||
"provider": credentials.provider,
|
||||
"type": credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "anthropic"
|
||||
]
|
||||
# OpenAI Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": openai_credentials.id,
|
||||
"provider": openai_credentials.provider,
|
||||
"type": openai_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "openai"
|
||||
]
|
||||
# Groq Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {"id": groq_credentials.id},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "groq"
|
||||
]
|
||||
# Open Router Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "open_router"
|
||||
]
|
||||
# Llama API Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": llama_api_credentials.id,
|
||||
"provider": llama_api_credentials.provider,
|
||||
"type": llama_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "llama_api"
|
||||
]
|
||||
# v0 by Vercel Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": v0_credentials.id,
|
||||
"provider": v0_credentials.provider,
|
||||
"type": v0_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "v0"
|
||||
]
|
||||
# AI/ML Api Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": aiml_api_credentials.id,
|
||||
"provider": aiml_api_credentials.provider,
|
||||
"type": aiml_api_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "aiml_api"
|
||||
]
|
||||
)
|
||||
}
|
||||
costs.append(
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter=cost_filter,
|
||||
cost_amount=cost.credit_cost,
|
||||
)
|
||||
)
|
||||
return costs
|
||||
|
||||
|
||||
def refresh_llm_costs() -> None:
    """Rebuild ``LLM_COST`` in place from the current LLM registry contents.

    The replacement list is built *before* the shared list is touched and is
    then swapped in with a single slice assignment. Compared with clearing
    first and extending afterwards this means:

    * a failure while building the new costs leaves the previous costs intact
      instead of an empty list, and
    * readers never observe the intermediate "no LLM costs" state.

    The list object itself is preserved, so anything already holding a
    reference to ``LLM_COST`` sees the refreshed entries.
    """
    refreshed_costs = _build_llm_costs_from_registry()
    LLM_COST[:] = refreshed_costs
|
||||
|
||||
|
||||
# Initial load will happen after registry is refreshed at startup
|
||||
# Don't call refresh_llm_costs() here - it will be called after registry refresh
|
||||
|
||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||
|
||||
|
||||
8
autogpt_platform/backend/backend/data/llm_model_types.py
Normal file
8
autogpt_platform/backend/backend/data/llm_model_types.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
|
||||
214
autogpt_platform/backend/backend/data/llm_registry.py
Normal file
214
autogpt_platform/backend/backend/data/llm_registry.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_model_types import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RegistryModel:
    """Immutable snapshot of one LLM model held in the registry."""

    # Unique key of the model; the registry indexes models by this value.
    slug: str
    # Human-readable name; used as the dropdown label and sort key.
    display_name: str
    description: str | None
    # Provider name, context window and output-token limit.
    metadata: ModelMetadata
    capabilities: dict[str, Any]
    # Additional metadata stored with the model.
    extra_metadata: dict[str, Any]
    provider_display_name: str
    # Only enabled models are offered as schema options.
    is_enabled: bool
    # Cost entries for this model; empty when none are defined.
    costs: tuple[RegistryModelCost, ...] = ()
|
||||
|
||||
|
||||
# Metadata and flat costs registered from code, keyed by the stringified key.
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
# Models loaded by refresh_llm_registry(), keyed by slug.
_dynamic_models: dict[str, RegistryModel] = {}
# Cached dropdown options and slug -> provider mapping. Both containers are
# cleared and refilled in place by _refresh_cached_schema().
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
# Held for the whole of refresh_llm_registry().
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
    """Merge code-defined model metadata into the static registry.

    Keys are converted with ``str()`` before being stored. The cached schema
    options and discriminator mapping are rebuilt afterwards.
    """
    stringified = {str(model_key): entry for model_key, entry in metadata.items()}
    _static_metadata.update(stringified)
    _refresh_cached_schema()
|
||||
|
||||
|
||||
def register_static_costs(costs: dict[Any, int]) -> None:
    """Merge code-defined flat costs into the static registry.

    Keys are converted with ``str()`` before being stored.
    """
    stringified = {str(model_key): amount for model_key, amount in costs.items()}
    _static_costs.update(stringified)
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
    """Return the option dicts for the model selection dropdown.

    Enabled dynamic models come first, ordered case-insensitively by display
    name. Static models follow in registration order, but only when their slug
    is not present among the dynamic models at all (enabled or disabled).
    """
    by_display_name = sorted(
        _dynamic_models.values(), key=lambda entry: entry.display_name.lower()
    )
    options: list[dict[str, str]] = [
        {
            "label": entry.display_name,
            "value": entry.slug,
            "group": entry.metadata.provider,
            "description": entry.description or "",
        }
        for entry in by_display_name
        if entry.is_enabled
    ]

    # Static entries have no display name or description: the slug is the label.
    options.extend(
        {
            "label": slug,
            "value": slug,
            "group": static_entry.provider,
            "description": "",
        }
        for slug, static_entry in _static_metadata.items()
        if slug not in _dynamic_models
    )
    return options
|
||||
|
||||
|
||||
def _registry_model_from_record(record: Any) -> RegistryModel:
    """Convert one ``LlmModel`` database record into a ``RegistryModel``."""
    provider = record.Provider
    # Without a Provider relation, the raw provider id stands in for both names.
    provider_name = provider.name if provider else record.providerId
    provider_display_name = provider.displayName if provider else record.providerId

    cost_entries = tuple(
        RegistryModelCost(
            credit_cost=cost_record.creditCost,
            credential_provider=cost_record.credentialProvider,
            credential_id=cost_record.credentialId,
            credential_type=cost_record.credentialType,
            currency=cost_record.currency,
            metadata=cost_record.metadata or {},
        )
        for cost_record in (record.Costs or [])
    )

    return RegistryModel(
        slug=record.slug,
        display_name=record.displayName,
        description=record.description,
        metadata=ModelMetadata(
            provider=provider_name,
            context_window=record.contextWindow,
            max_output_tokens=record.maxOutputTokens,
        ),
        capabilities=record.capabilities or {},
        extra_metadata=record.metadata or {},
        provider_display_name=provider_display_name,
        is_enabled=record.isEnabled,
        costs=cost_entries,
    )


async def refresh_llm_registry() -> None:
    """Reload the dynamic model registry from the database.

    All models are loaded, enabled and disabled alike. If the query fails the
    error is logged and the function returns without touching the current
    registry contents. On success the dynamic models are replaced in place
    and the cached schema options / discriminator mapping are rebuilt.
    """
    async with _lock:
        try:
            records = await prisma.models.LlmModel.prisma().find_many(
                include={"Provider": True, "Costs": True}
            )
            logger.debug("Found %d LLM model records in database", len(records))
        except Exception as exc:
            logger.error(
                "Failed to refresh LLM registry from DB: %s", exc, exc_info=True
            )
            return

        refreshed: dict[str, RegistryModel] = {
            record.slug: _registry_model_from_record(record) for record in records
        }

        _dynamic_models.clear()
        _dynamic_models.update(refreshed)
        _refresh_cached_schema()

        enabled_count = sum(1 for model in refreshed.values() if model.is_enabled)
        logger.info(
            "LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
            len(refreshed),
            enabled_count,
            len(refreshed) - enabled_count,
        )
|
||||
|
||||
|
||||
def _refresh_cached_schema() -> None:
    """Rebuild the cached dropdown options and slug -> provider mapping.

    Both module-level containers are refilled in place, so their identity is
    preserved. In the mapping, dynamic models take precedence; static metadata
    only contributes slugs that have no dynamic entry.
    """
    _schema_options[:] = _build_schema_options()

    _discriminator_mapping.clear()
    for slug, model in _dynamic_models.items():
        _discriminator_mapping[slug] = model.metadata.provider
    for slug, static_entry in _static_metadata.items():
        if slug not in _discriminator_mapping:
            _discriminator_mapping[slug] = static_entry.provider
|
||||
|
||||
|
||||
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
    """Look up metadata for ``slug``.

    Dynamic models win over static metadata. Returns ``None`` when the slug is
    known to neither.
    """
    dynamic_model = _dynamic_models.get(slug)
    if dynamic_model is not None:
        return dynamic_model.metadata
    return _static_metadata.get(slug)
|
||||
|
||||
|
||||
# Removed get_llm_model_metadata_async - direct database queries don't work in executor context
|
||||
# The registry should be refreshed on startup via initialize_blocks() or rest_api lifespan
|
||||
|
||||
|
||||
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
    """Return the cost entries for ``slug``.

    A dynamic model's own costs are returned as-is. Otherwise a registered
    static cost is wrapped in a single entry whose credential provider is
    ``"static"``. An empty tuple is returned when no cost is known.
    """
    dynamic_model = _dynamic_models.get(slug)
    if dynamic_model is not None:
        return dynamic_model.costs

    static_cost = _static_costs.get(slug)
    if static_cost is None:
        return ()

    fallback_entry = RegistryModelCost(
        credit_cost=static_cost,
        credential_provider="static",
        credential_id=None,
        credential_type=None,
        currency=None,
        metadata={},
    )
    return (fallback_entry,)
|
||||
|
||||
|
||||
def get_llm_model_schema_options() -> list[dict[str, str]]:
    """Return the dropdown options for LLM model selection.

    The cache is rebuilt from the current registry state on every call so the
    enabled/disabled status is up to date. The returned list is the shared
    module-level cache object, not a copy.
    """
    _refresh_cached_schema()
    return _schema_options
|
||||
|
||||
|
||||
def get_llm_discriminator_mapping() -> dict[str, str]:
    """Return the model slug -> provider mapping.

    The cache is rebuilt from the current registry state on every call. The
    returned dict is the shared module-level cache object, not a copy.
    """
    _refresh_cached_schema()
    return _discriminator_mapping
|
||||
|
||||
|
||||
def get_dynamic_model_slugs() -> set[str]:
    """Return a new set holding the slug of every dynamic model."""
    return set(_dynamic_models)
|
||||
|
||||
|
||||
def iter_dynamic_models() -> Iterable[RegistryModel]:
    """Return a tuple snapshot of all dynamic models (enabled and disabled)."""
    snapshot = tuple(_dynamic_models.values())
    return snapshot
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Redis pub/sub notifications for LLM registry updates.
|
||||
|
||||
When models are added/updated/removed via the admin UI, this module
|
||||
publishes notifications to Redis that all executor services subscribe to,
|
||||
ensuring they refresh their registry cache in real-time.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.redis_client import get_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis channel name for LLM registry refresh notifications
|
||||
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
|
||||
|
||||
|
||||
def publish_registry_refresh_notification() -> None:
    """Publish a "refresh" message on the LLM registry Redis channel.

    Subscribers of ``REGISTRY_REFRESH_CHANNEL`` use this as the signal to
    reload their registry. Publishing is best-effort: any failure is logged as
    a warning and not propagated to the caller.
    """
    try:
        get_redis().publish(REGISTRY_REFRESH_CHANNEL, "refresh")
    except Exception as exc:
        logger.warning(
            "Failed to publish LLM registry refresh notification: %s", exc, exc_info=True
        )
    else:
        logger.info("Published LLM registry refresh notification to Redis")
|
||||
|
||||
|
||||
async def subscribe_to_registry_refresh(
    on_refresh: Any,  # Async callable that takes no args
) -> None:
    """Listen for LLM registry refresh notifications and run ``on_refresh``.

    Subscribes to ``REGISTRY_REFRESH_CHANNEL`` and then polls for messages in
    an endless loop; the function does not return on its own.

    Args:
        on_refresh: Async callable, invoked without arguments for every
            refresh notification received. Exceptions it raises are logged
            and do not stop the loop.

    Raises:
        Exception: Whatever connecting or subscribing raises, after logging.
    """
    from backend.data.redis_client import connect_async

    try:
        client = await connect_async()
        listener = client.pubsub()
        await listener.subscribe(REGISTRY_REFRESH_CHANNEL)
        logger.info(
            "Subscribed to LLM registry refresh notifications on channel: %s",
            REGISTRY_REFRESH_CHANNEL,
        )

        while True:
            try:
                incoming = await listener.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                is_refresh = (
                    incoming
                    and incoming["type"] == "message"
                    and incoming["channel"] == REGISTRY_REFRESH_CHANNEL
                )
                if not is_refresh:
                    continue

                logger.info("Received LLM registry refresh notification")
                try:
                    await on_refresh()
                except Exception as exc:
                    logger.error(
                        "Error refreshing LLM registry from notification: %s",
                        exc,
                        exc_info=True,
                    )
            except Exception as exc:
                logger.warning(
                    "Error processing registry refresh message: %s", exc, exc_info=True
                )
                # Keep listening after a failed poll, but back off briefly.
                await asyncio.sleep(1)
    except Exception as exc:
        logger.error(
            "Failed to subscribe to LLM registry refresh notifications: %s",
            exc,
            exc_info=True,
        )
        raise
|
||||
|
||||
@@ -531,6 +531,61 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
|
||||
# For LLM credentials fields, ensure discriminator and discriminator_mapping are populated
|
||||
# This handles the case where the mapping was empty at field definition time
|
||||
# Check all properties for credentials fields
|
||||
properties = schema.get("properties", {})
|
||||
for field_name, field_schema in properties.items():
|
||||
if not isinstance(field_schema, dict):
|
||||
continue
|
||||
|
||||
# Check if this is a credentials field (has credentials_provider)
|
||||
if "credentials_provider" in field_schema:
|
||||
# Check if this field should have a discriminator (multiple providers)
|
||||
providers = field_schema.get("credentials_provider", [])
|
||||
if isinstance(providers, list) and len(providers) > 1:
|
||||
# This is a multi-provider credentials field - ensure discriminator is set
|
||||
if "discriminator" not in field_schema:
|
||||
# Try to get discriminator from the field definition
|
||||
# Check if this is an LLM credentials field by looking at the model fields
|
||||
discriminator_found = False
|
||||
try:
|
||||
if hasattr(model_class, "model_fields") and field_name in model_class.model_fields:
|
||||
field_info = model_class.model_fields[field_name]
|
||||
# Check json_schema_extra on the FieldInfo
|
||||
if hasattr(field_info, "json_schema_extra"):
|
||||
if isinstance(field_info.json_schema_extra, dict):
|
||||
discriminator = field_info.json_schema_extra.get("discriminator")
|
||||
if discriminator:
|
||||
field_schema["discriminator"] = discriminator
|
||||
discriminator_found = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If still not found and this looks like an LLM field (has multiple AI providers),
|
||||
# default to "model" discriminator
|
||||
if not discriminator_found:
|
||||
# Check if providers include common LLM providers
|
||||
llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", "aiml_api", "v0"}
|
||||
if any(p in llm_providers for p in providers):
|
||||
field_schema["discriminator"] = "model"
|
||||
|
||||
# If discriminator is "model", ensure discriminator_mapping is populated
|
||||
if field_schema.get("discriminator") == "model":
|
||||
mapping = field_schema.get("discriminator_mapping")
|
||||
# If mapping is empty, missing, or None, refresh from registry
|
||||
if not mapping or (isinstance(mapping, dict) and len(mapping) == 0):
|
||||
try:
|
||||
from backend.data import llm_registry
|
||||
refreshed_mapping = llm_registry.get_llm_discriminator_mapping()
|
||||
if refreshed_mapping:
|
||||
field_schema["discriminator_mapping"] = refreshed_mapping
|
||||
except Exception:
|
||||
# If registry isn't available, ensure at least an empty dict is present
|
||||
if "discriminator_mapping" not in field_schema:
|
||||
field_schema["discriminator_mapping"] = {}
|
||||
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -680,16 +735,20 @@ def CredentialsField(
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
|
||||
field_schema_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
|
||||
field_schema_extra: dict[str, Any] = {}
|
||||
|
||||
# Always include discriminator if provided
|
||||
if discriminator is not None:
|
||||
field_schema_extra["discriminator"] = discriminator
|
||||
# Always include discriminator_mapping when discriminator is set (even if empty initially)
|
||||
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
|
||||
|
||||
# Include other optional fields (only if not None)
|
||||
if required_scopes:
|
||||
field_schema_extra["credentials_scopes"] = list(required_scopes)
|
||||
if discriminator_values:
|
||||
field_schema_extra["discriminator_values"] = discriminator_values
|
||||
|
||||
# Merge any json_schema_extra passed in kwargs
|
||||
if "json_schema_extra" in kwargs:
|
||||
|
||||
@@ -621,6 +621,72 @@ class ExecutionProcessor:
|
||||
)
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
|
||||
# Initialize blocks and refresh LLM registry in the execution loop
|
||||
async def init_registry():
|
||||
try:
|
||||
from backend.data import db, llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.block import initialize_blocks
|
||||
|
||||
# Connect to database for registry refresh
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
logger.info("[GraphExecutor] Connected to database for registry refresh")
|
||||
|
||||
# Refresh LLM registry before initializing blocks
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
logger.info("[GraphExecutor] LLM registry refreshed")
|
||||
|
||||
# Initialize blocks (this also refreshes registry, but we do it explicitly above)
|
||||
await initialize_blocks()
|
||||
logger.info("[GraphExecutor] Blocks initialized")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Refresh registry function for notifications
|
||||
async def refresh_registry_from_notification():
|
||||
"""Refresh LLM registry when notified via Redis pub/sub"""
|
||||
try:
|
||||
from backend.data import db, llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
# Ensure DB is connected
|
||||
if not db.is_connected():
|
||||
await db.connect()
|
||||
|
||||
# Refresh registry and costs
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they regenerate with new model options
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
logger.info("[GraphExecutor] LLM registry refreshed from notification")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Schedule registry refresh in the execution loop
|
||||
asyncio.run_coroutine_threadsafe(init_registry(), self.node_execution_loop)
|
||||
|
||||
# Subscribe to registry refresh notifications
|
||||
async def subscribe_to_refresh():
|
||||
from backend.data.llm_registry_notifications import subscribe_to_registry_refresh
|
||||
await subscribe_to_registry_refresh(refresh_registry_from_notification)
|
||||
|
||||
# Start subscription in a background task
|
||||
asyncio.run_coroutine_threadsafe(subscribe_to_refresh(), self.node_execution_loop)
|
||||
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
|
||||
@@ -122,6 +122,24 @@ class ConnectionManager:
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
    """Send one message to every active websocket connection.

    The message is serialized once and sent to a snapshot of the connections
    taken at call time. Send failures are collected by ``asyncio.gather``
    rather than raised.

    Returns:
        The number of connections the send was attempted on (0 when there
        are none).
    """
    payload = WSMessage(method=method, data=data).model_dump_json()

    recipients = tuple(self.active_connections)
    if recipients:
        await asyncio.gather(
            *[recipient.send_text(payload) for recipient in recipients],
            return_exceptions=True,
        )
    return len(recipients)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -20,11 +20,14 @@ import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.execution_analytics_routes
|
||||
import backend.server.v2.admin.llm_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
@@ -36,6 +39,7 @@ import backend.server.v2.library.routes
|
||||
import backend.server.v2.otto.routes
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.server.v2.llm.routes as public_llm_routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
@@ -104,11 +108,20 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
await llm_registry.refresh_llm_registry()
|
||||
refresh_llm_costs()
|
||||
|
||||
# Clear block schema caches so they're regenerated with updated discriminator_mapping
|
||||
from backend.data.block import BlockSchema
|
||||
BlockSchema.clear_all_schema_caches()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
# Note: migrate_llm_models may need updating if it references LlmModel enum
|
||||
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -280,6 +293,16 @@ app.include_router(
|
||||
tags=["v2", "executions", "review"],
|
||||
prefix="/api/review",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.admin.llm_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/llm",
|
||||
)
|
||||
app.include_router(
|
||||
public_llm_routes.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
|
||||
145
autogpt_platform/backend/backend/server/v2/admin/llm_routes.py
Normal file
145
autogpt_platform/backend/backend/server/v2/admin/llm_routes.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import logging
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Every route on this router requires an authenticated admin user.
router = fastapi.APIRouter(
    prefix="/admin/llm",
    tags=["llm", "admin"],
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
|
||||
|
||||
|
||||
def _clear_blocks_endpoint_cache() -> None:
    """Best-effort clear of the cached ``/blocks`` endpoint response."""
    try:
        from backend.server.routers.v1 import _get_cached_blocks

        _get_cached_blocks.cache_clear()
        logger.info("Cleared /blocks endpoint cache")
    except Exception as e:
        logger.warning("Failed to clear /blocks cache: %s", e)


def _clear_builder_providers_cache() -> None:
    """Best-effort clear of the v2 builder providers cache, if it exists."""
    try:
        from backend.server.v2.builder import db as builder_db

        if not hasattr(builder_db, "_get_all_providers"):
            return
        builder_db._get_all_providers.cache_clear()
        logger.info("Cleared v2 builder providers cache")
    except Exception as e:
        logger.debug("Could not clear v2 builder cache: %s", e)


async def _refresh_runtime_state() -> None:
    """Reload the LLM registry and invalidate everything derived from it.

    Steps, in order: refresh the registry and LLM costs, clear all block
    schema caches, clear the ``/blocks`` and v2 builder caches (best-effort),
    then publish a Redis notification for other services.
    """
    logger.info("Refreshing LLM registry runtime state...")

    await llm_registry.refresh_llm_registry()
    refresh_llm_costs()

    # Imported inside the function, as in the rest of this module's
    # cache-clearing code.
    from backend.data.block import BlockSchema

    BlockSchema.clear_all_schema_caches()
    logger.info("Cleared all block schema caches")

    _clear_blocks_endpoint_cache()
    _clear_builder_providers_cache()

    from backend.data.llm_registry_notifications import (
        publish_registry_refresh_notification,
    )

    publish_registry_refresh_notification()
    logger.info("Published registry refresh notification")
|
||||
|
||||
|
||||
@router.get(
    "/providers",
    summary="List LLM providers",
    response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
    # Admin listing of providers; `include_models` is forwarded unchanged
    # to the data layer.
    return llm_model.LlmProvidersResponse(
        providers=await llm_db.list_providers(include_models=include_models)
    )
|
||||
|
||||
|
||||
@router.post(
    "/providers",
    summary="Create LLM provider",
    response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
    # Persist the new provider first, then refresh the runtime state so the
    # change is picked up before the response is returned.
    created = await llm_db.upsert_provider(request=request)
    await _refresh_runtime_state()
    return created
|
||||
|
||||
|
||||
@router.patch(
    "/providers/{provider_id}",
    summary="Update LLM provider",
    response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
    provider_id: str,
    request: llm_model.UpsertLlmProviderRequest,
):
    # Passing `provider_id` makes the data layer update the existing row
    # instead of creating a new one.
    updated = await llm_db.upsert_provider(
        request=request,
        provider_id=provider_id,
    )
    await _refresh_runtime_state()
    return updated
|
||||
|
||||
|
||||
@router.get(
    "/models",
    summary="List LLM models",
    response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(provider_id: str | None = fastapi.Query(default=None)):
    # Admin listing of models; the optional `provider_id` query parameter
    # is forwarded to the data layer as a filter.
    return llm_model.LlmModelsResponse(
        models=await llm_db.list_models(provider_id=provider_id)
    )
|
||||
|
||||
|
||||
@router.post(
    "/models",
    summary="Create LLM model",
    response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
    # Persist the model, then refresh the runtime state before responding.
    created = await llm_db.create_model(request=request)
    await _refresh_runtime_state()
    return created
|
||||
|
||||
|
||||
@router.patch(
    "/models/{model_id}",
    summary="Update LLM model",
    response_model=llm_model.LlmModel,
)
async def update_llm_model(
    model_id: str,
    request: llm_model.UpdateLlmModelRequest,
):
    # Apply the update to the stored model, then refresh the runtime state
    # before responding.
    updated = await llm_db.update_model(model_id=model_id, request=request)
    await _refresh_runtime_state()
    return updated
|
||||
|
||||
|
||||
@router.patch(
    "/models/{model_id}/toggle",
    summary="Toggle LLM model availability",
    response_model=llm_model.LlmModel,
)
async def toggle_llm_model(
    model_id: str,
    request: llm_model.ToggleLlmModelRequest,
):
    # Any failure in the toggle or in the subsequent runtime refresh is
    # logged with its traceback and surfaced to the client as a generic 500.
    try:
        toggled = await llm_db.toggle_model(
            model_id=model_id,
            is_enabled=request.is_enabled,
        )
        await _refresh_runtime_state()
    except Exception as error:
        logger.exception("Failed to toggle LLM model %s: %s", model_id, error)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to toggle model availability",
        ) from error
    else:
        return toggled
|
||||
|
||||
205
autogpt_platform/backend/backend/server/v2/llm/db.py
Normal file
205
autogpt_platform/backend/backend/server/v2/llm/db.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
|
||||
def _json_dict(value: Any | None) -> dict[str, Any]:
|
||||
if not value:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {}
|
||||
|
||||
|
||||
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
    """Convert a Prisma ``LlmModelCost`` row into its API model."""
    fields = {
        "id": record.id,
        "unit": record.unit,
        "credit_cost": record.creditCost,
        "credential_provider": record.credentialProvider,
        "credential_id": record.credentialId,
        "credential_type": record.credentialType,
        "currency": record.currency,
        # Anything that is not a non-empty dict is normalised to {}.
        "metadata": _json_dict(record.metadata),
    }
    return llm_model.LlmModelCost(**fields)
|
||||
|
||||
|
||||
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
    """Convert a Prisma ``LlmModel`` row, with any loaded costs, into its API model."""
    # `Costs` is empty or unset when the relation was not included in the
    # query; either way the API model gets an empty list.
    mapped_costs = [_map_cost(cost_row) for cost_row in record.Costs or []]

    fields = {
        "id": record.id,
        "slug": record.slug,
        "display_name": record.displayName,
        "description": record.description,
        "provider_id": record.providerId,
        "context_window": record.contextWindow,
        "max_output_tokens": record.maxOutputTokens,
        "is_enabled": record.isEnabled,
        "capabilities": _json_dict(record.capabilities),
        "metadata": _json_dict(record.metadata),
        "costs": mapped_costs,
    }
    return llm_model.LlmModel(**fields)
|
||||
|
||||
|
||||
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
    """Convert a Prisma ``LlmProvider`` row, with any loaded models, into its API model."""
    # `Models` is empty or unset when the relation was not included.
    mapped_models: list[llm_model.LlmModel] = [
        _map_model(model_row) for model_row in record.Models or []
    ]

    fields = {
        "id": record.id,
        "name": record.name,
        "display_name": record.displayName,
        "description": record.description,
        "default_credential_provider": record.defaultCredentialProvider,
        "default_credential_id": record.defaultCredentialId,
        "default_credential_type": record.defaultCredentialType,
        "supports_tools": record.supportsTools,
        "supports_json_output": record.supportsJsonOutput,
        "supports_reasoning": record.supportsReasoning,
        "supports_parallel_tool": record.supportsParallelTool,
        "metadata": _json_dict(record.metadata),
        "models": mapped_models,
    }
    return llm_model.LlmProvider(**fields)
|
||||
|
||||
|
||||
async def list_providers(include_models: bool = True) -> list[llm_model.LlmProvider]:
    """Fetch all LLM providers.

    Args:
        include_models: When true, each provider's models (and each model's
            costs) are loaded along with it.
    """
    include = None
    if include_models:
        include = {"Models": {"include": {"Costs": True}}}

    rows = await prisma.models.LlmProvider.prisma().find_many(include=include)
    return list(map(_map_provider, rows))
|
||||
|
||||
|
||||
async def upsert_provider(
    request: llm_model.UpsertLlmProviderRequest,
    provider_id: str | None = None,
) -> llm_model.LlmProvider:
    """Create a provider, or update the one identified by ``provider_id``.

    Args:
        request: Field values to write.
        provider_id: When truthy, the existing provider with this id is
            updated; otherwise a new provider is created.

    Returns:
        The stored provider, including its models and their costs.

    Raises:
        ValueError: If ``provider_id`` is given but no such provider exists.
    """
    data: dict[str, Any] = {
        "name": request.name,
        "displayName": request.display_name,
        "description": request.description,
        "defaultCredentialProvider": request.default_credential_provider,
        "defaultCredentialId": request.default_credential_id,
        "defaultCredentialType": request.default_credential_type,
        "supportsTools": request.supports_tools,
        "supportsJsonOutput": request.supports_json_output,
        "supportsReasoning": request.supports_reasoning,
        "supportsParallelTool": request.supports_parallel_tool,
        # Prisma Client Python expects Json column values to be wrapped in
        # `Json`; a bare dict is treated as a nested input object.
        "metadata": prisma.Json(request.metadata),
    }
    include = {"Models": {"include": {"Costs": True}}}
    providers = prisma.models.LlmProvider.prisma()

    if provider_id:
        record = await providers.update(
            where={"id": provider_id},
            data=data,
            include=include,
        )
        # `update` yields None when no row matches the `where` clause.
        if record is None:
            raise ValueError(f"LLM provider '{provider_id}' not found")
    else:
        record = await providers.create(data=data, include=include)

    return _map_provider(record)
|
||||
|
||||
|
||||
async def list_models(provider_id: str | None = None) -> list[llm_model.LlmModel]:
    """Fetch LLM models with their costs, optionally limited to one provider.

    Args:
        provider_id: When truthy, only models of this provider are returned.
    """
    where = None
    if provider_id:
        where = {"providerId": provider_id}

    rows = await prisma.models.LlmModel.prisma().find_many(
        where=where,
        include={"Costs": True},
    )
    return list(map(_map_model, rows))
|
||||
|
||||
|
||||
def _cost_create_payload(
    costs: Sequence[llm_model.LlmModelCostInput],
) -> dict[str, Iterable[dict[str, Any]]]:
    """Build the nested Prisma ``create`` payload for a model's cost rows.

    Args:
        costs: Cost inputs to turn into row payloads; may be empty.

    Returns:
        A dict with a single ``"create"`` key holding one row dict per input.
    """
    return {
        "create": [
            {
                "unit": cost.unit,
                "creditCost": cost.credit_cost,
                "credentialProvider": cost.credential_provider,
                "credentialId": cost.credential_id,
                "credentialType": cost.credential_type,
                "currency": cost.currency,
                # Json columns must be wrapped in `Json` for Prisma Client
                # Python; a bare dict is read as a nested input object.
                "metadata": prisma.Json(cost.metadata),
            }
            for cost in costs
        ]
    }
|
||||
|
||||
|
||||
async def create_model(
    request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
    """Create an LLM model together with its cost rows.

    Args:
        request: Field values and cost entries for the new model.

    Returns:
        The stored model, including the created costs.
    """
    record = await prisma.models.LlmModel.prisma().create(
        data={
            "slug": request.slug,
            "displayName": request.display_name,
            "description": request.description,
            "providerId": request.provider_id,
            "contextWindow": request.context_window,
            "maxOutputTokens": request.max_output_tokens,
            "isEnabled": request.is_enabled,
            # Json columns must be wrapped in `Json` for Prisma Client
            # Python; a bare dict is read as a nested input object.
            "capabilities": prisma.Json(request.capabilities),
            "metadata": prisma.Json(request.metadata),
            "Costs": _cost_create_payload(request.costs),
        },
        include={"Costs": True},
    )
    return _map_model(record)
|
||||
|
||||
|
||||
async def update_model(
    model_id: str,
    request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
    """Apply a partial update to an LLM model.

    Only request fields that are not ``None`` are written. When
    ``request.costs`` is provided, the update carries a ``deleteMany`` for
    the model's existing cost rows together with creates for the supplied
    list.

    Args:
        model_id: Id of the model to update.
        request: Fields to change; ``None`` means "leave unchanged".

    Returns:
        The updated model, including its costs.

    Raises:
        ValueError: If no model with ``model_id`` exists.
    """
    scalar_fields: dict[str, Any] = {
        "displayName": request.display_name,
        "description": request.description,
        "contextWindow": request.context_window,
        "maxOutputTokens": request.max_output_tokens,
        "isEnabled": request.is_enabled,
        "providerId": request.provider_id,
    }
    data: dict[str, Any] = {
        column: value for column, value in scalar_fields.items() if value is not None
    }

    # Json columns must be wrapped in `Json` for Prisma Client Python; a
    # bare dict is read as a nested input object.
    if request.capabilities is not None:
        data["capabilities"] = prisma.Json(request.capabilities)
    if request.metadata is not None:
        data["metadata"] = prisma.Json(request.metadata)

    if request.costs is not None:
        data["Costs"] = {
            "deleteMany": {"llmModelId": model_id},
            **_cost_create_payload(request.costs),
        }

    record = await prisma.models.LlmModel.prisma().update(
        where={"id": model_id},
        data=data,
        include={"Costs": True},
    )
    # `update` yields None when no row matches the `where` clause.
    if record is None:
        raise ValueError(f"LLM model '{model_id}' not found")
    return _map_model(record)
|
||||
|
||||
|
||||
async def toggle_model(model_id: str, is_enabled: bool) -> llm_model.LlmModel:
    """Set the ``isEnabled`` flag of an LLM model.

    Args:
        model_id: Id of the model to change.
        is_enabled: New value of the flag.

    Returns:
        The updated model, including its costs.

    Raises:
        ValueError: If no model with ``model_id`` exists.
    """
    record = await prisma.models.LlmModel.prisma().update(
        where={"id": model_id},
        data={"isEnabled": is_enabled},
        include={"Costs": True},
    )
    # `update` yields None when no row matches the `where` clause.
    if record is None:
        raise ValueError(f"LLM model '{model_id}' not found")
    return _map_model(record)
|
||||
|
||||
109
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
109
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
    # API representation of one stored cost entry of an LLM model.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema is unchanged.)
    id: str
    # Billing unit; RUN when not specified.
    unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
    credit_cost: int
    credential_provider: str
    credential_id: Optional[str] = None
    credential_type: Optional[str] = None
    currency: Optional[str] = None
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
    # API representation of one stored LLM model and its cost entries.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema is unchanged.)
    id: str
    slug: str
    display_name: str
    description: Optional[str] = None
    provider_id: str
    context_window: int
    max_output_tokens: Optional[int] = None
    is_enabled: bool = True
    # Free-form JSON objects; each instance gets its own empty dict by default.
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    # Empty when the model has no costs or they were not loaded.
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
    # API representation of one stored LLM provider and, optionally, its
    # models. (Documented with comments rather than a docstring so the
    # generated JSON schema is unchanged.)
    id: str
    name: str
    display_name: str
    description: Optional[str] = None
    # Default credential settings for the provider; all optional.
    default_credential_provider: Optional[str] = None
    default_credential_id: Optional[str] = None
    default_credential_type: Optional[str] = None
    # Capability flags.
    supports_tools: bool = True
    supports_json_output: bool = True
    supports_reasoning: bool = False
    supports_parallel_tool: bool = False
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    # Empty when the provider has no models or they were not loaded.
    models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
    # Response envelope for the provider listing endpoints.
    providers: list[LlmProvider]
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
    # Response envelope for the model listing endpoints.
    models: list[LlmModel]
|
||||
|
||||
|
||||
class UpsertLlmProviderRequest(pydantic.BaseModel):
    # Request body used both to create a provider and to update an existing
    # one. (Documented with comments rather than a docstring so the
    # generated JSON schema is unchanged.)
    name: str
    display_name: str
    description: Optional[str] = None
    default_credential_provider: Optional[str] = None
    default_credential_id: Optional[str] = None
    # Unlike the other credential defaults, this one falls back to "api_key".
    default_credential_type: Optional[str] = "api_key"
    # Capability flags.
    supports_tools: bool = True
    supports_json_output: bool = True
    supports_reasoning: bool = False
    supports_parallel_tool: bool = False
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCostInput(pydantic.BaseModel):
    # One cost entry as supplied by a client when creating or updating a
    # model; it carries no `id`. (Documented with comments rather than a
    # docstring so the generated JSON schema is unchanged.)
    # Billing unit; RUN when not specified.
    unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
    credit_cost: int
    credential_provider: str
    credential_id: Optional[str] = None
    credential_type: Optional[str] = "api_key"
    currency: Optional[str] = None
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class CreateLlmModelRequest(pydantic.BaseModel):
    # Request body for creating an LLM model. (Documented with comments
    # rather than a docstring so the generated JSON schema is unchanged.)
    slug: str
    display_name: str
    description: Optional[str] = None
    provider_id: str
    context_window: int
    max_output_tokens: Optional[int] = None
    is_enabled: bool = True
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
    # Required field (no default); the list itself may be empty.
    costs: list[LlmModelCostInput]
|
||||
|
||||
|
||||
class UpdateLlmModelRequest(pydantic.BaseModel):
    # Request body for a partial model update: every field is optional and
    # defaults to None. (Documented with comments rather than a docstring
    # so the generated JSON schema is unchanged.)
    display_name: Optional[str] = None
    description: Optional[str] = None
    context_window: Optional[int] = None
    max_output_tokens: Optional[int] = None
    is_enabled: Optional[bool] = None
    capabilities: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    provider_id: Optional[str] = None
    costs: Optional[list[LlmModelCostInput]] = None
|
||||
|
||||
|
||||
class ToggleLlmModelRequest(pydantic.BaseModel):
    # Request body for the toggle endpoint: the desired enabled state.
    is_enabled: bool
|
||||
|
||||
24
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
24
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.server.v2.llm import db as llm_db
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models():
    # User-facing listing: returns every model the data layer reports,
    # with no provider filter.
    return llm_model.LlmModelsResponse(models=await llm_db.list_models())
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    # User-facing listing: providers are always returned with their models.
    return llm_model.LlmProvidersResponse(
        providers=await llm_db.list_providers(include_models=True)
    )
|
||||
|
||||
@@ -77,7 +77,32 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
async def registry_refresh_worker():
|
||||
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
|
||||
from backend.data.llm_registry_notifications import REGISTRY_REFRESH_CHANNEL
|
||||
from backend.data.redis_client import connect_async
|
||||
|
||||
redis = await connect_async()
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
|
||||
logger.info("Subscribed to LLM registry refresh notifications for WebSocket broadcast")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message" and message["channel"] == REGISTRY_REFRESH_CHANNEL:
|
||||
logger.info("Broadcasting LLM registry refresh to all WebSocket clients")
|
||||
await manager.broadcast_to_all(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={
|
||||
"type": "LLM_REGISTRY_REFRESH",
|
||||
"event": "registry_updated",
|
||||
},
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
execution_worker(),
|
||||
notification_worker(),
|
||||
registry_refresh_worker(),
|
||||
)
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
|
||||
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
-- Seed LLM Registry from existing hard-coded data
|
||||
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
|
||||
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
|
||||
|
||||
-- Insert Providers
|
||||
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
|
||||
VALUES
|
||||
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
|
||||
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
|
||||
ON CONFLICT ("name") DO NOTHING;
|
||||
|
||||
-- Insert Models (using CTEs to reference provider IDs)
|
||||
WITH provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
model_slug,
|
||||
model_display_name,
|
||||
NULL,
|
||||
p."id",
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
true,
|
||||
'{}'::jsonb,
|
||||
'{}'::jsonb
|
||||
FROM (VALUES
|
||||
-- OpenAI models
|
||||
('o3', 'O3', 'openai', 200000, 100000),
|
||||
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
|
||||
('o1', 'O1', 'openai', 200000, 100000),
|
||||
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
|
||||
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
|
||||
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
|
||||
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
|
||||
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
|
||||
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
|
||||
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1047576, 32768),
|
||||
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
|
||||
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
|
||||
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
|
||||
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
|
||||
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
|
||||
-- Anthropic models
|
||||
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
|
||||
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
|
||||
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
|
||||
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
|
||||
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
|
||||
-- AI/ML API models
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
|
||||
-- Groq models
|
||||
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
|
||||
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
|
||||
-- Ollama models
|
||||
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
|
||||
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
|
||||
('llama3', 'Llama 3', 'ollama', 8192, NULL),
|
||||
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
|
||||
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
|
||||
-- OpenRouter models
|
||||
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
|
||||
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
|
||||
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
|
||||
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
|
||||
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
|
||||
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
|
||||
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
|
||||
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
|
||||
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
|
||||
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
|
||||
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
|
||||
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
|
||||
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
|
||||
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
|
||||
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
|
||||
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
|
||||
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
|
||||
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
|
||||
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
|
||||
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
|
||||
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
|
||||
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
|
||||
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
|
||||
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
|
||||
-- Llama API models
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
|
||||
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
|
||||
-- v0 models
|
||||
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
|
||||
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
|
||||
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
|
||||
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
|
||||
JOIN provider_ids p ON p."name" = models.provider_name
|
||||
ON CONFLICT ("slug") DO NOTHING;
|
||||
|
||||
-- Insert Costs (using CTEs to reference model IDs)
|
||||
WITH model_ids AS (
|
||||
SELECT "id", "slug", "providerId" FROM "LlmModel"
|
||||
),
|
||||
provider_ids AS (
|
||||
SELECT "id", "name" FROM "LlmProvider"
|
||||
)
|
||||
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
'RUN'::"LlmCostUnit",
|
||||
cost,
|
||||
p."name",
|
||||
NULL,
|
||||
'api_key',
|
||||
NULL,
|
||||
'{}'::jsonb,
|
||||
m."id"
|
||||
FROM (VALUES
|
||||
-- OpenAI costs
|
||||
('o3', 4),
|
||||
('o3-mini', 2),
|
||||
('o1', 16),
|
||||
('o1-mini', 4),
|
||||
('gpt-5-2025-08-07', 2),
|
||||
('gpt-5.1-2025-11-13', 5),
|
||||
('gpt-5-mini-2025-08-07', 1),
|
||||
('gpt-5-nano-2025-08-07', 1),
|
||||
('gpt-5-chat-latest', 5),
|
||||
('gpt-4.1-2025-04-14', 2),
|
||||
('gpt-4.1-mini-2025-04-14', 1),
|
||||
('gpt-4o-mini', 1),
|
||||
('gpt-4o', 3),
|
||||
('gpt-4-turbo', 10),
|
||||
('gpt-3.5-turbo', 1),
|
||||
-- Anthropic costs
|
||||
('claude-opus-4-1-20250805', 21),
|
||||
('claude-opus-4-20250514', 21),
|
||||
('claude-sonnet-4-20250514', 5),
|
||||
('claude-haiku-4-5-20251001', 4),
|
||||
('claude-opus-4-5-20251101', 14),
|
||||
('claude-sonnet-4-5-20250929', 9),
|
||||
('claude-3-7-sonnet-20250219', 5),
|
||||
('claude-3-haiku-20240307', 1),
|
||||
-- AI/ML API costs
|
||||
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
|
||||
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
|
||||
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
|
||||
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
|
||||
-- Groq costs
|
||||
('llama-3.3-70b-versatile', 1),
|
||||
('llama-3.1-8b-instant', 1),
|
||||
-- Ollama costs
|
||||
('llama3.3', 1),
|
||||
('llama3.2', 1),
|
||||
('llama3', 1),
|
||||
('llama3.1:405b', 1),
|
||||
('dolphin-mistral:latest', 1),
|
||||
-- OpenRouter costs
|
||||
('google/gemini-2.5-pro-preview-03-25', 4),
|
||||
('google/gemini-3-pro-preview', 5),
|
||||
('mistralai/mistral-nemo', 1),
|
||||
('cohere/command-r-08-2024', 1),
|
||||
('cohere/command-r-plus-08-2024', 3),
|
||||
('deepseek/deepseek-chat', 2),
|
||||
('perplexity/sonar', 1),
|
||||
('perplexity/sonar-pro', 5),
|
||||
('perplexity/sonar-deep-research', 10),
|
||||
('nousresearch/hermes-3-llama-3.1-405b', 1),
|
||||
('nousresearch/hermes-3-llama-3.1-70b', 1),
|
||||
('amazon/nova-lite-v1', 1),
|
||||
('amazon/nova-micro-v1', 1),
|
||||
('amazon/nova-pro-v1', 1),
|
||||
('microsoft/wizardlm-2-8x22b', 1),
|
||||
('gryphe/mythomax-l2-13b', 1),
|
||||
('meta-llama/llama-4-scout', 1),
|
||||
('meta-llama/llama-4-maverick', 1),
|
||||
('x-ai/grok-4', 9),
|
||||
('x-ai/grok-4-fast', 1),
|
||||
('x-ai/grok-4.1-fast', 1),
|
||||
('x-ai/grok-code-fast-1', 1),
|
||||
('moonshotai/kimi-k2', 1),
|
||||
('qwen/qwen3-235b-a22b-thinking-2507', 1),
|
||||
('qwen/qwen3-coder', 9),
|
||||
('google/gemini-2.5-flash', 1),
|
||||
('google/gemini-2.0-flash-001', 1),
|
||||
('google/gemini-2.5-flash-lite-preview-06-17', 1),
|
||||
('google/gemini-2.0-flash-lite-001', 1),
|
||||
('deepseek/deepseek-r1-0528', 1),
|
||||
('openai/gpt-oss-120b', 1),
|
||||
('openai/gpt-oss-20b', 1),
|
||||
-- Llama API costs
|
||||
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
|
||||
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
|
||||
('Llama-3.3-8B-Instruct', 1),
|
||||
('Llama-3.3-70B-Instruct', 1),
|
||||
-- v0 costs
|
||||
('v0-1.5-md', 1),
|
||||
('v0-1.5-lg', 2),
|
||||
('v0-1.0-md', 1)
|
||||
) AS costs(model_slug, cost)
|
||||
JOIN model_ids m ON m."slug" = costs.model_slug
|
||||
JOIN provider_ids p ON p."id" = m."providerId";
|
||||
|
||||
@@ -954,3 +954,84 @@ enum APIKeyStatus {
|
||||
REVOKED
|
||||
SUSPENDED
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
///////////// LLM REGISTRY AND BILLING DATA /////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Unit in which an LlmModelCost row is charged.
// NOTE(review): presumably RUN = per block execution and TOKENS = per token
// volume — confirm against the billing code.
enum LlmCostUnit {
  RUN
  TOKENS
}
|
||||
|
||||
// An upstream LLM vendor; owns zero or more LlmModel rows.
model LlmProvider {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // Identity: `name` is the unique slug, `displayName` is the label.
  name        String  @unique
  displayName String
  description String?

  // Optional default credential reference (all three columns are nullable).
  defaultCredentialProvider String?
  defaultCredentialId       String?
  defaultCredentialType     String?

  // Capability flags and their defaults for newly created providers.
  supportsTools        Boolean @default(true)
  supportsJsonOutput   Boolean @default(true)
  supportsReasoning    Boolean @default(false)
  supportsParallelTool Boolean @default(false)

  // Free-form JSON, defaults to an empty object.
  metadata Json @default("{}")

  Models LlmModel[]
}
|
||||
|
||||
// A single LLM model registered under a provider, addressed by unique slug.
model LlmModel {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // Identity
  slug        String  @unique
  displayName String
  description String?

  // Owning provider; Restrict blocks deleting a provider that still has models.
  providerId String
  Provider   LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)

  // Limits and availability
  contextWindow   Int
  maxOutputTokens Int?
  isEnabled       Boolean @default(true)

  // Free-form JSON blobs, both default to an empty object.
  capabilities Json @default("{}")
  metadata     Json @default("{}")

  Costs LlmModelCost[]

  @@index([providerId, isEnabled])
  // NOTE(review): `slug` is already @unique, which normally creates an index;
  // this extra index looks redundant — confirm before dropping it together
  // with the matching migration.
  @@index([slug])
}
|
||||
|
||||
// Credit price entry attached to one LlmModel; removed with its model (Cascade).
model LlmModelCost {
  id        String      @id @default(uuid())
  createdAt DateTime    @default(now())
  updatedAt DateTime    @updatedAt
  unit      LlmCostUnit @default(RUN)

  // Price expressed as an integer number of credits.
  creditCost Int

  // Credential this price applies to; only the provider is mandatory.
  credentialProvider String
  credentialId       String?
  credentialType     String?
  currency           String?

  // Free-form JSON, defaults to an empty object.
  metadata Json @default("{}")

  llmModelId String
  Model      LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)

  @@index([llmModelId])
  @@index([credentialProvider])
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Users, DollarSign, UserSearch, FileText, Cpu } from "lucide-react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -26,6 +26,11 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "LLM Registry",
|
||||
href: "/admin/llms",
|
||||
icon: <Cpu className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
"use server";
|
||||
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
import type {
|
||||
CreateLlmModelRequest,
|
||||
LlmModelsResponse,
|
||||
LlmProvidersResponse,
|
||||
ToggleLlmModelRequest,
|
||||
UpdateLlmModelRequest,
|
||||
UpsertLlmProviderRequest,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
const ADMIN_LLM_PATH = "/admin/llms";
|
||||
|
||||
/**
 * Server action: load the LLM providers shown on the admin registry page.
 *
 * Calls `listAdminLlmProviders(true)` on a fresh `BackendApi` client.
 * NOTE(review): the meaning of the `true` argument is defined by the API
 * client (presumably "include models") — confirm there.
 *
 * @returns The providers response exactly as returned by the backend client.
 */
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
  const client = new BackendApi();
  const response = await client.listAdminLlmProviders(true);
  return response;
}
|
||||
|
||||
/**
 * Server action: load the LLM models shown on the admin registry page.
 *
 * @returns The models response exactly as returned by
 *   `BackendApi.listAdminLlmModels`.
 */
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
  const client = new BackendApi();
  const response = await client.listAdminLlmModels();
  return response;
}
|
||||
|
||||
/**
 * Server action: create an LLM provider from the "Add Provider" form.
 *
 * Builds an `UpsertLlmProviderRequest` from the submitted fields, sends it via
 * `BackendApi.createAdminLlmProvider`, then revalidates the admin LLM page.
 *
 * @param formData - Submitted form fields. Capability checkboxes may be paired
 *   with a same-named hidden `off` input, so every value submitted under a
 *   name is inspected, not just the first one.
 */
export async function createLlmProviderAction(formData: FormData) {
  // Optional text field: an empty or missing value is sent as `undefined`.
  const optionalText = (field: string): string | undefined => {
    const raw = formData.get(field);
    return raw ? String(raw) : undefined;
  };

  // `FormData.get` only returns the FIRST entry for a name. With the
  // hidden-"off" + checkbox pattern that first entry is always "off", so a
  // ticked box used to be read as unticked. Look at all entries instead.

  // Opt-in flag: true only when some submitted value is "on".
  const isOn = (field: string): boolean =>
    formData.getAll(field).includes("on");

  // Opt-out flag: true when nothing was submitted or any value is not "off".
  const isNotOff = (field: string): boolean => {
    const entries = formData.getAll(field);
    return entries.length === 0 || entries.some((entry) => entry !== "off");
  };

  const payload: UpsertLlmProviderRequest = {
    name: String(formData.get("name") || "").trim(),
    display_name: String(formData.get("display_name") || "").trim(),
    description: optionalText("description"),
    default_credential_provider: optionalText("default_credential_provider"),
    default_credential_id: optionalText("default_credential_id"),
    default_credential_type: optionalText("default_credential_type"),
    supports_tools: isOn("supports_tools"),
    supports_json_output: isNotOff("supports_json_output"),
    supports_reasoning: isOn("supports_reasoning"),
    supports_parallel_tool: isOn("supports_parallel_tool"),
    metadata: {},
  };

  const api = new BackendApi();
  await api.createAdminLlmProvider(payload);
  // Refresh the server-rendered registry page so the new provider appears.
  revalidatePath(ADMIN_LLM_PATH);
}
|
||||
|
||||
/**
 * Server action: create an LLM model (with one cost entry) from the
 * "Add Model" form.
 *
 * Builds a `CreateLlmModelRequest`, sends it via
 * `BackendApi.createAdminLlmModel`, then revalidates the admin LLM page.
 *
 * @param formData - Submitted form fields. `is_enabled` may be submitted
 *   twice (hidden `off` input plus the checkbox), so all of its values are
 *   inspected.
 */
export async function createLlmModelAction(formData: FormData) {
  // Optional text field: an empty or missing value is sent as `undefined`.
  const optionalText = (field: string): string | undefined => {
    const raw = formData.get(field);
    return raw ? String(raw) : undefined;
  };

  // `FormData.get` returns only the first entry for a name; the form places a
  // hidden "off" input before the checkbox, so reading the first entry always
  // produced a disabled model. Enabled unless every submitted value is "off";
  // a missing field keeps the previous default of enabled.
  const enabledEntries = formData.getAll("is_enabled");
  const isEnabled =
    enabledEntries.length === 0 ||
    enabledEntries.some((entry) => entry !== "off");

  const maxOutputTokens = formData.get("max_output_tokens");

  const payload: CreateLlmModelRequest = {
    slug: String(formData.get("slug") || "").trim(),
    display_name: String(formData.get("display_name") || "").trim(),
    description: optionalText("description"),
    provider_id: String(formData.get("provider_id")),
    context_window: Number(formData.get("context_window") || 0),
    max_output_tokens: maxOutputTokens ? Number(maxOutputTokens) : undefined,
    is_enabled: isEnabled,
    capabilities: {},
    metadata: {},
    // The form captures exactly one cost entry per model.
    costs: [
      {
        credit_cost: Number(formData.get("credit_cost") || 0),
        credential_provider: String(
          formData.get("credential_provider") || "",
        ).trim(),
        credential_id: optionalText("credential_id"),
        credential_type: optionalText("credential_type"),
        metadata: {},
      },
    ],
  };

  const api = new BackendApi();
  await api.createAdminLlmModel(payload);
  // Refresh the server-rendered registry page so the new model appears.
  revalidatePath(ADMIN_LLM_PATH);
}
|
||||
|
||||
/**
 * Server action: update an existing LLM model from the edit dialog.
 *
 * Every field is optional: an empty or missing form value is sent as
 * `undefined`. The model is identified by the `model_id` form field. After
 * the backend call the admin LLM page is revalidated.
 *
 * @param formData - Submitted form fields. `is_enabled` may be submitted
 *   twice (hidden `off` input plus the checkbox), so all of its values are
 *   inspected.
 */
export async function updateLlmModelAction(formData: FormData) {
  const modelId = String(formData.get("model_id"));

  // Optional text field: an empty or missing value is sent as `undefined`.
  const optionalText = (field: string): string | undefined => {
    const raw = formData.get(field);
    return raw ? String(raw) : undefined;
  };

  // Optional numeric field: an empty or missing value is sent as `undefined`.
  const optionalNumber = (field: string): number | undefined => {
    const raw = formData.get(field);
    return raw ? Number(raw) : undefined;
  };

  // `FormData.get` returns only the first entry for a name; the edit form
  // places a hidden "off" input before the checkbox, so reading the first
  // entry disabled the model on every save. Check all entries: enabled when
  // any of them is "on", left untouched when nothing (non-empty) was sent.
  const enabledEntries = formData.getAll("is_enabled");
  const isEnabled = enabledEntries.some(Boolean)
    ? enabledEntries.includes("on")
    : undefined;

  const creditCost = formData.get("credit_cost");

  const payload: UpdateLlmModelRequest = {
    display_name: optionalText("display_name"),
    description: optionalText("description"),
    provider_id: optionalText("provider_id"),
    context_window: optionalNumber("context_window"),
    max_output_tokens: optionalNumber("max_output_tokens"),
    is_enabled: isEnabled,
    // Costs are only sent when a credit cost was submitted.
    costs: creditCost
      ? [
          {
            credit_cost: Number(creditCost),
            credential_provider: String(
              formData.get("credential_provider") || "",
            ).trim(),
            credential_id: optionalText("credential_id"),
            credential_type: optionalText("credential_type"),
            metadata: {},
          },
        ]
      : undefined,
  };

  const api = new BackendApi();
  await api.updateAdminLlmModel(modelId, payload);
  // Refresh the server-rendered registry page so the change is visible.
  revalidatePath(ADMIN_LLM_PATH);
}
|
||||
|
||||
/**
 * Server action: enable or disable a single LLM model.
 *
 * Reads `model_id` and `is_enabled` from the form; the model is enabled only
 * when `is_enabled` is exactly the string "true". Revalidates the admin LLM
 * page afterwards.
 */
export async function toggleLlmModelAction(formData: FormData) {
  const targetId = String(formData.get("model_id"));
  const request: ToggleLlmModelRequest = {
    is_enabled: formData.get("is_enabled") === "true",
  };

  const client = new BackendApi();
  await client.toggleAdminLlmModel(targetId, request);
  revalidatePath(ADMIN_LLM_PATH);
}
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
import type { LlmProvider } from "@/lib/autogpt-server-api/types";
|
||||
import { createLlmModelAction } from "../actions";
|
||||
|
||||
/**
 * Form for registering a new LLM model; submits to `createLlmModelAction`.
 *
 * @param providers - Providers offered in the "Provider" select.
 */
export function AddModelForm({ providers }: { providers: LlmProvider[] }) {
  // Class strings shared by the repeated form controls.
  const fieldClass =
    "w-full rounded-md border border-input bg-background px-4 py-2.5 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20";
  const captionClass = "text-sm font-medium text-foreground";
  const headingClass = "text-base font-semibold text-foreground";
  const hintClass = "text-xs text-muted-foreground";
  const dividedSectionClass = "space-y-5 border-t border-border pt-6";

  const providerOptions = providers.map((provider) => (
    <option key={provider.id} value={provider.id}>
      {provider.display_name} ({provider.name})
    </option>
  ));

  return (
    <form
      action={createLlmModelAction}
      className="space-y-8 rounded-lg border border-border bg-card p-8 shadow-sm"
    >
      <div className="space-y-2">
        <h2 className="text-2xl font-semibold tracking-tight">Add Model</h2>
        <p className="text-sm text-muted-foreground">
          Register a new model slug, metadata, and pricing.
        </p>
      </div>

      <div className="space-y-8">
        {/* Basic information */}
        <div className="space-y-5">
          <div className="space-y-1">
            <h3 className={headingClass}>Basic Information</h3>
            <p className={hintClass}>Core model details</p>
          </div>
          <div className="grid gap-5 md:grid-cols-2">
            <label className="space-y-2.5">
              <span className={captionClass}>Model Slug</span>
              <input
                required
                name="slug"
                className={fieldClass}
                placeholder="gpt-4.1-mini-2025-04-14"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Display Name</span>
              <input
                required
                name="display_name"
                className={fieldClass}
                placeholder="GPT 4.1 Mini"
              />
            </label>
          </div>
          <label className="space-y-2.5">
            <span className={captionClass}>Description</span>
            <textarea
              name="description"
              rows={3}
              className={fieldClass}
              placeholder="Optional description..."
            />
          </label>
        </div>

        {/* Model configuration */}
        <div className={dividedSectionClass}>
          <div className="space-y-1">
            <h3 className={headingClass}>Model Configuration</h3>
            <p className={hintClass}>Model capabilities and limits</p>
          </div>
          <div className="grid gap-5 md:grid-cols-3">
            <label className="space-y-2.5">
              <span className={captionClass}>Provider</span>
              <select
                required
                name="provider_id"
                className={fieldClass}
                defaultValue=""
              >
                <option value="" disabled>
                  Select provider
                </option>
                {providerOptions}
              </select>
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Context Window</span>
              <input
                required
                type="number"
                name="context_window"
                className={fieldClass}
                placeholder="128000"
                min={1}
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Max Output Tokens</span>
              <input
                type="number"
                name="max_output_tokens"
                className={fieldClass}
                placeholder="16384"
                min={1}
              />
            </label>
          </div>
        </div>

        {/* Pricing and credentials */}
        <div className={dividedSectionClass}>
          <div className="space-y-1">
            <h3 className={headingClass}>Pricing & Credentials</h3>
            <p className={hintClass}>Cost and credential configuration</p>
          </div>
          <div className="grid gap-5 md:grid-cols-2 lg:grid-cols-4">
            <label className="space-y-2.5">
              <span className={captionClass}>Credit Cost</span>
              <input
                required
                type="number"
                name="credit_cost"
                className={fieldClass}
                placeholder="5"
                min={0}
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Credential Provider</span>
              <input
                required
                name="credential_provider"
                className={fieldClass}
                placeholder="openai"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Credential ID</span>
              <input
                name="credential_id"
                className={fieldClass}
                placeholder="cred-id"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Credential Type</span>
              <input
                name="credential_type"
                defaultValue="api_key"
                className={fieldClass}
              />
            </label>
          </div>
          <p className="mt-3 text-xs text-muted-foreground">
            Credit cost is always in platform credits.
          </p>
        </div>

        {/* Enabled flag: hidden "off" input is submitted alongside the checkbox */}
        <div className="space-y-3 border-t border-border pt-6">
          <label className="flex items-center gap-3 text-sm font-medium">
            <input type="hidden" name="is_enabled" value="off" />
            <input
              type="checkbox"
              name="is_enabled"
              defaultChecked
              className="h-4 w-4 rounded border-input"
            />
            Enabled by default
          </label>
        </div>
      </div>

      <div className="flex justify-end border-t border-border pt-6">
        <button
          type="submit"
          className="inline-flex items-center rounded-md bg-primary px-8 py-3 text-sm font-semibold text-primary-foreground shadow-sm transition-colors hover:bg-primary/90"
        >
          Save Model
        </button>
      </div>
    </form>
  );
}
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import { createLlmProviderAction } from "../actions";
|
||||
|
||||
/**
 * Form for registering a new LLM provider; submits to
 * `createLlmProviderAction`.
 */
export function AddProviderForm() {
  // Class strings shared by the repeated form controls.
  const fieldClass =
    "w-full rounded-md border border-input bg-background px-4 py-2.5 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20";
  const captionClass = "text-sm font-medium text-foreground";
  const headingClass = "text-base font-semibold text-foreground";
  const hintClass = "text-xs text-muted-foreground";
  const dividedSectionClass = "space-y-5 border-t border-border pt-6";

  // Capability checkboxes in display order, with their initial state.
  const capabilityToggles = [
    { name: "supports_tools", label: "Supports tools", checked: true },
    {
      name: "supports_json_output",
      label: "Supports JSON output",
      checked: true,
    },
    { name: "supports_reasoning", label: "Supports reasoning", checked: false },
    {
      name: "supports_parallel_tool",
      label: "Supports parallel tool calls",
      checked: false,
    },
  ];

  return (
    <form
      action={createLlmProviderAction}
      className="space-y-8 rounded-lg border border-border bg-card p-8 shadow-sm"
    >
      <div className="space-y-2">
        <h2 className="text-2xl font-semibold tracking-tight">Add Provider</h2>
        <p className="text-sm text-muted-foreground">
          Define a new upstream provider and default credential information.
        </p>
      </div>

      <div className="space-y-8">
        {/* Basic information */}
        <div className="space-y-5">
          <div className="space-y-1">
            <h3 className={headingClass}>Basic Information</h3>
            <p className={hintClass}>Core provider details</p>
          </div>
          <div className="grid gap-5 md:grid-cols-2">
            <label className="space-y-2.5">
              <span className={captionClass}>Provider Slug</span>
              <input
                required
                name="name"
                className={fieldClass}
                placeholder="e.g. openai"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Display Name</span>
              <input
                required
                name="display_name"
                className={fieldClass}
                placeholder="OpenAI"
              />
            </label>
          </div>
          <label className="space-y-2.5">
            <span className={captionClass}>Description</span>
            <textarea
              name="description"
              rows={3}
              className={fieldClass}
              placeholder="Optional description..."
            />
          </label>
        </div>

        {/* Default credentials */}
        <div className={dividedSectionClass}>
          <div className="space-y-1">
            <h3 className={headingClass}>Default Credentials</h3>
            <p className={hintClass}>Default credential configuration</p>
          </div>
          <div className="grid gap-5 md:grid-cols-3">
            <label className="space-y-2.5">
              <span className={captionClass}>Credential Provider</span>
              <input
                name="default_credential_provider"
                className={fieldClass}
                placeholder="openai"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Credential ID</span>
              <input
                name="default_credential_id"
                className={fieldClass}
                placeholder="cred-id"
              />
            </label>
            <label className="space-y-2.5">
              <span className={captionClass}>Credential Type</span>
              <input
                name="default_credential_type"
                defaultValue="api_key"
                className={fieldClass}
              />
            </label>
          </div>
        </div>

        {/* Capabilities: each checkbox is preceded by a hidden "off" input */}
        <div className={dividedSectionClass}>
          <div className="space-y-1">
            <h3 className={headingClass}>Capabilities</h3>
            <p className={hintClass}>Provider feature flags</p>
          </div>
          <div className="grid gap-4 sm:grid-cols-2">
            {capabilityToggles.map((toggle) => (
              <label
                key={toggle.name}
                className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 text-sm font-medium transition-colors hover:bg-muted/50"
              >
                <input type="hidden" name={toggle.name} value="off" />
                <input
                  type="checkbox"
                  name={toggle.name}
                  defaultChecked={toggle.checked}
                  className="h-4 w-4 rounded border-input"
                />
                {toggle.label}
              </label>
            ))}
          </div>
        </div>
      </div>

      <div className="flex justify-end border-t border-border pt-6">
        <button
          type="submit"
          className="inline-flex items-center rounded-md bg-primary px-8 py-3 text-sm font-semibold text-primary-foreground shadow-sm transition-colors hover:bg-primary/90"
        >
          Save Provider
        </button>
      </div>
    </form>
  );
}
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import type { LlmModel, LlmProvider } from "@/lib/autogpt-server-api/types";
|
||||
import { updateLlmModelAction } from "../actions";
|
||||
|
||||
/**
 * "Edit" button plus dialog for updating one LLM model.
 *
 * The form is pre-filled from `model` and its first cost entry, submits to
 * `updateLlmModelAction`, and closes the dialog once that action resolves.
 *
 * @param model - Model being edited; only `model.costs[0]` is editable here.
 * @param providers - Providers offered in the "Provider" select.
 */
export function EditModelModal({
  model,
  providers,
}: {
  model: LlmModel;
  providers: LlmProvider[];
}) {
  const [open, setOpen] = useState(false);
  // Only the first cost entry is shown; it may be undefined when none exist.
  const cost = model.costs[0];

  // Class string shared by every text/number/select control in the dialog.
  const inputClass =
    "mt-1 w-full rounded border border-input bg-background p-2 text-sm";

  return (
    <Dialog open={open} onOpenChange={setOpen}>
      <DialogTrigger asChild>
        <button
          type="button"
          className="inline-flex items-center rounded border border-input px-3 py-1 text-xs font-semibold hover:bg-muted"
        >
          Edit
        </button>
      </DialogTrigger>
      <DialogContent className="max-w-2xl max-h-[90vh] overflow-y-auto">
        <DialogHeader>
          <DialogTitle>Edit Model</DialogTitle>
          <DialogDescription>
            Update model metadata and pricing information.
          </DialogDescription>
        </DialogHeader>
        <form
          action={async (formData) => {
            await updateLlmModelAction(formData);
            setOpen(false);
          }}
          className="space-y-4"
        >
          <input type="hidden" name="model_id" value={model.id} />

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Display Name
              <input
                required
                name="display_name"
                defaultValue={model.display_name}
                className={inputClass}
              />
            </label>
            <label className="text-sm font-medium">
              Provider
              <select
                required
                name="provider_id"
                className={inputClass}
                defaultValue={model.provider_id}
              >
                {providers.map((p) => (
                  <option key={p.id} value={p.id}>
                    {p.display_name} ({p.name})
                  </option>
                ))}
              </select>
            </label>
          </div>

          <label className="text-sm font-medium">
            Description
            <textarea
              name="description"
              rows={2}
              defaultValue={model.description ?? ""}
              className={inputClass}
              placeholder="Optional description..."
            />
          </label>

          <div className="grid gap-4 md:grid-cols-2">
            <label className="text-sm font-medium">
              Context Window
              <input
                required
                type="number"
                name="context_window"
                defaultValue={model.context_window}
                className={inputClass}
                min={1}
              />
            </label>
            <label className="text-sm font-medium">
              Max Output Tokens
              <input
                type="number"
                name="max_output_tokens"
                defaultValue={model.max_output_tokens ?? undefined}
                className={inputClass}
                min={1}
              />
            </label>
          </div>

          <div className="grid gap-4 md:grid-cols-4">
            <label className="text-sm font-medium">
              Credit Cost
              <input
                required
                type="number"
                name="credit_cost"
                defaultValue={cost?.credit_cost ?? 0}
                className={inputClass}
                min={0}
              />
            </label>
            <label className="text-sm font-medium">
              Credential Provider
              <input
                required
                name="credential_provider"
                defaultValue={cost?.credential_provider ?? ""}
                className={inputClass}
              />
            </label>
            <label className="text-sm font-medium">
              Credential ID
              <input
                name="credential_id"
                defaultValue={cost?.credential_id ?? ""}
                className={inputClass}
                placeholder="cred-id"
              />
            </label>
            <label className="text-sm font-medium">
              Credential Type
              <input
                name="credential_type"
                defaultValue={cost?.credential_type ?? "api_key"}
                className={inputClass}
              />
            </label>
          </div>

          {/* Hidden "off" input is submitted alongside the checkbox value. */}
          <label className="flex items-center gap-2 text-sm font-medium">
            <input type="hidden" name="is_enabled" value="off" />
            <input
              type="checkbox"
              name="is_enabled"
              defaultChecked={model.is_enabled}
            />
            Enabled
          </label>

          <div className="flex justify-end gap-2 pt-4">
            <button
              type="button"
              onClick={() => setOpen(false)}
              className="inline-flex items-center rounded border border-input px-4 py-2 text-sm font-semibold hover:bg-muted"
            >
              Cancel
            </button>
            <button
              type="submit"
              className="inline-flex items-center rounded bg-primary px-4 py-2 text-sm font-semibold text-primary-foreground hover:bg-primary/90"
            >
              Update Model
            </button>
          </div>
        </form>
      </DialogContent>
    </Dialog>
  );
}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import type { LlmModel, LlmProvider } from "@/lib/autogpt-server-api/types";
|
||||
import { toggleLlmModelAction } from "../actions";
|
||||
import { EditModelModal } from "./EditModelModal";
|
||||
|
||||
/**
 * Table of registered LLM models with provider, limits, first cost entry,
 * status badge, and per-row toggle/edit actions.
 *
 * Renders a placeholder box instead of the table when `models` is empty.
 *
 * @param models - Models to list, one row each.
 * @param providers - Providers used to resolve each model's `provider_id`
 *   and passed on to the edit dialog.
 */
export function ModelsTable({
  models,
  providers,
}: {
  models: LlmModel[];
  providers: LlmProvider[];
}) {
  if (models.length === 0) {
    return (
      <div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
        No models registered yet.
      </div>
    );
  }

  // Index providers by id so each row resolves its provider with one lookup.
  const providerById = new Map<LlmProvider["id"], LlmProvider>();
  for (const entry of providers) {
    providerById.set(entry.id, entry);
  }

  const rows = models.map((model) => {
    // Only the first cost entry is displayed.
    const firstCost = model.costs[0];
    const owner = providerById.get(model.provider_id);
    const badgeTone = model.is_enabled
      ? "bg-green-100 text-green-700"
      : "bg-muted text-muted-foreground";

    return (
      <TableRow key={model.id} className={model.is_enabled ? "" : "opacity-60"}>
        <TableCell>
          <div className="font-medium">{model.display_name}</div>
          <div className="text-xs text-muted-foreground">{model.slug}</div>
        </TableCell>
        <TableCell>
          {/* Fall back to the raw provider id when the provider is unknown. */}
          {owner ? (
            <>
              <div>{owner.display_name}</div>
              <div className="text-xs text-muted-foreground">{owner.name}</div>
            </>
          ) : (
            model.provider_id
          )}
        </TableCell>
        <TableCell>{model.context_window.toLocaleString()}</TableCell>
        <TableCell>
          {model.max_output_tokens
            ? model.max_output_tokens.toLocaleString()
            : "—"}
        </TableCell>
        <TableCell>
          {firstCost ? (
            <>
              <div className="font-medium">{firstCost.credit_cost} credits</div>
              <div className="text-xs text-muted-foreground">
                {firstCost.credential_provider}
              </div>
            </>
          ) : (
            "—"
          )}
        </TableCell>
        <TableCell>
          <span
            className={`inline-flex rounded-full px-2 py-0.5 text-xs font-semibold ${badgeTone}`}
          >
            {model.is_enabled ? "Enabled" : "Disabled"}
          </span>
        </TableCell>
        <TableCell className="text-right text-sm">
          <div className="flex items-center justify-end gap-2">
            <ToggleModelButton modelId={model.id} isEnabled={model.is_enabled} />
            <EditModelModal model={model} providers={providers} />
          </div>
        </TableCell>
      </TableRow>
    );
  });

  return (
    <div className="rounded-lg border">
      <Table>
        <TableHeader>
          <TableRow>
            <TableHead>Model</TableHead>
            <TableHead>Provider</TableHead>
            <TableHead>Context Window</TableHead>
            <TableHead>Max Output</TableHead>
            <TableHead>Cost</TableHead>
            <TableHead>Status</TableHead>
            <TableHead className="text-right">Actions</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>{rows}</TableBody>
      </Table>
    </div>
  );
}
|
||||
|
||||
/**
 * One-button form that flips a model's enabled state via
 * `toggleLlmModelAction`.
 *
 * The hidden `is_enabled` field carries the desired NEW state ("true" or
 * "false"), i.e. the negation of the current one.
 *
 * @param modelId - Id submitted as `model_id`.
 * @param isEnabled - Current state; decides the button label and the
 *   submitted target state.
 */
function ToggleModelButton({
  modelId,
  isEnabled,
}: {
  modelId: string;
  isEnabled: boolean;
}) {
  const nextState = String(!isEnabled);
  const buttonLabel = isEnabled ? "Disable" : "Enable";

  return (
    <form action={toggleLlmModelAction}>
      <input type="hidden" name="model_id" value={modelId} />
      <input type="hidden" name="is_enabled" value={nextState} />
      <button
        type="submit"
        className="inline-flex items-center rounded border border-input px-3 py-1 text-xs font-semibold hover:bg-muted"
      >
        {buttonLabel}
      </button>
    </form>
  );
}
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import type { LlmProvider } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
/**
 * Read-only table of LLM providers: slug, display name, default credential,
 * and a badge per enabled capability flag.
 *
 * Renders a placeholder box instead of the table when `providers` is empty.
 *
 * @param providers - Providers to list, one row each.
 */
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
  if (providers.length === 0) {
    return (
      <div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
        No providers configured yet.
      </div>
    );
  }

  const rows = providers.map((provider) => {
    // "provider (id)" when a default credential provider is set, else a dash.
    const defaultCredential = provider.default_credential_provider
      ? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
      : "—";

    // Capability badges in display order; only flags that are set render.
    const capabilityBadges = [
      { label: "Tools", enabled: provider.supports_tools },
      { label: "JSON", enabled: provider.supports_json_output },
      { label: "Reasoning", enabled: provider.supports_reasoning },
      { label: "Parallel Tools", enabled: provider.supports_parallel_tool },
    ];

    return (
      <TableRow key={provider.id}>
        <TableCell className="font-medium">{provider.name}</TableCell>
        <TableCell>{provider.display_name}</TableCell>
        <TableCell>{defaultCredential}</TableCell>
        <TableCell className="text-sm text-muted-foreground">
          <div className="flex flex-wrap gap-2">
            {capabilityBadges.map(
              (badge) =>
                badge.enabled && (
                  <span
                    key={badge.label}
                    className="rounded bg-muted px-2 py-0.5 text-xs"
                  >
                    {badge.label}
                  </span>
                ),
            )}
          </div>
        </TableCell>
      </TableRow>
    );
  });

  return (
    <div className="rounded-lg border">
      <Table>
        <TableHeader>
          <TableRow>
            <TableHead>Name</TableHead>
            <TableHead>Display Name</TableHead>
            <TableHead>Default Credential</TableHead>
            <TableHead>Capabilities</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>{rows}</TableBody>
      </Table>
    </div>
  );
}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import {
|
||||
fetchLlmModels,
|
||||
fetchLlmProviders,
|
||||
} from "./actions";
|
||||
import { AddProviderForm } from "./components/AddProviderForm";
|
||||
import { AddModelForm } from "./components/AddModelForm";
|
||||
import { ProviderList } from "./components/ProviderList";
|
||||
import { ModelsTable } from "./components/ModelsTable";
|
||||
|
||||
/** Heading + muted description shown above each list section of the dashboard. */
function SectionHeading({
  title,
  description,
}: {
  title: string;
  description: string;
}) {
  return (
    <div className="space-y-2">
      <h2 className="text-3xl font-semibold tracking-tight">{title}</h2>
      <p className="text-sm text-muted-foreground">{description}</p>
    </div>
  );
}

/**
 * Admin dashboard body for the LLM registry.
 *
 * Fetches providers and models concurrently, then renders the two "add"
 * forms followed by the provider list and the models table.
 */
async function LlmRegistryDashboard() {
  // Both requests are independent, so they are awaited together.
  const [{ providers }, { models }] = await Promise.all([
    fetchLlmProviders(),
    fetchLlmModels(),
  ]);

  return (
    <div className="mx-auto flex w-full max-w-7xl flex-col gap-12 p-8">
      <div className="space-y-2">
        <h1 className="text-4xl font-bold tracking-tight">LLM Registry</h1>
        <p className="text-base text-muted-foreground">
          Manage supported providers, models, and credit pricing
        </p>
      </div>

      <div className="grid gap-8 lg:grid-cols-2">
        <AddProviderForm />
        <AddModelForm providers={providers} />
      </div>

      <div className="space-y-6">
        <SectionHeading
          title="Providers"
          description="Default credentials and feature flags for upstream vendors."
        />
        <ProviderList providers={providers} />
      </div>

      <div className="space-y-6">
        <SectionHeading
          title="Models"
          description="Toggle availability, adjust context windows, and update credit pricing."
        />
        <ModelsTable models={models} providers={providers} />
      </div>
    </div>
  );
}
|
||||
|
||||
/**
 * Page entry point for the admin LLM registry.
 *
 * Wraps `LlmRegistryDashboard` with the wrapper obtained from
 * `withRoleAccess(["admin"])` and renders the wrapped component.
 */
export default async function AdminLlmRegistryPage() {
  "use server";
  const requireAdmin = await withRoleAccess(["admin"]);
  const GuardedDashboard = await requireAdmin(LlmRegistryDashboard);
  return <GuardedDashboard />;
}
|
||||
|
||||
@@ -7,8 +7,9 @@ import { BlockCategoryResponse } from "@/app/api/__generated__/models/blockCateg
|
||||
import { BlockResponse } from "@/app/api/__generated__/models/blockResponse";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useState } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
|
||||
export const useAllBlockContent = () => {
|
||||
const { toast } = useToast();
|
||||
@@ -93,6 +94,29 @@ export const useAllBlockContent = () => {
|
||||
const isErrorOnLoadingMore = (categoryName: string) =>
|
||||
errorLoadingCategories.has(categoryName);
|
||||
|
||||
// Listen for LLM registry refresh notifications
|
||||
useEffect(() => {
|
||||
const api = new BackendApi();
|
||||
const queryClient = getQueryClient();
|
||||
|
||||
const handleNotification = (notification: any) => {
|
||||
if (
|
||||
notification?.type === "LLM_REGISTRY_REFRESH" ||
|
||||
notification?.event === "registry_updated"
|
||||
) {
|
||||
// Invalidate all block-related queries to force refresh
|
||||
const categoriesQueryKey = getGetV2GetBuilderBlockCategoriesQueryKey();
|
||||
queryClient.invalidateQueries({ queryKey: categoriesQueryKey });
|
||||
}
|
||||
};
|
||||
|
||||
const unsubscribe = api.onWebSocketMessage("notification", handleNotification);
|
||||
|
||||
return () => {
|
||||
unsubscribe();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
data,
|
||||
isLoading,
|
||||
|
||||
@@ -41,6 +41,13 @@ const Dot: FC<{ isConnected: boolean; type?: string }> = memo(
|
||||
);
|
||||
Dot.displayName = "Dot";
|
||||
|
||||
/**
 * Returns `schema.type` when the schema carries a string `type` property,
 * otherwise `undefined` (e.g. for schema variants that have no `type` key).
 */
const getSchemaType = (schema: BlockIOSubSchema): string | undefined =>
  schema && "type" in schema && typeof schema.type === "string"
    ? schema.type
    : undefined;
|
||||
|
||||
const NodeHandle: FC<HandleProps> = ({
|
||||
keyName,
|
||||
schema,
|
||||
@@ -50,7 +57,8 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
title,
|
||||
className,
|
||||
}) => {
|
||||
const typeClass = `text-sm ${getTypeTextColor(schema.type || "any")} ${
|
||||
const schemaType = getSchemaType(schema);
|
||||
const typeClass = `text-sm ${getTypeTextColor(schemaType || "any")} ${
|
||||
side === "left" ? "text-left" : "text-right"
|
||||
}`;
|
||||
|
||||
@@ -66,7 +74,7 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
{isRequired ? "*" : ""}
|
||||
</span>
|
||||
<span className={`${typeClass} data-sentry-unmask flex items-end`}>
|
||||
({TYPE_NAME[schema.type as keyof typeof TYPE_NAME] || "any"})
|
||||
({TYPE_NAME[schemaType as keyof typeof TYPE_NAME] || "any"})
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
@@ -95,7 +103,7 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
className="group -ml-[38px]"
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
<Dot isConnected={isConnected} type={schema.type} />
|
||||
<Dot isConnected={isConnected} type={schemaType} />
|
||||
{label}
|
||||
</div>
|
||||
</Handle>
|
||||
@@ -118,7 +126,7 @@ const NodeHandle: FC<HandleProps> = ({
|
||||
>
|
||||
<div className="pointer-events-none flex items-center">
|
||||
{label}
|
||||
<Dot isConnected={isConnected} type={schema.type} />
|
||||
<Dot isConnected={isConnected} type={schemaType} />
|
||||
</div>
|
||||
</Handle>
|
||||
</div>
|
||||
|
||||
@@ -610,8 +610,11 @@ const NodeOneOfDiscriminatorField: FC<{
|
||||
|
||||
return oneOfVariants
|
||||
.map((variant) => {
|
||||
const variantDiscValue = variant.properties?.[discriminatorProperty]
|
||||
?.const as string; // NOTE: can discriminators only be strings?
|
||||
const discProperty = variant.properties?.[discriminatorProperty];
|
||||
const variantDiscValue =
|
||||
discProperty && "const" in discProperty
|
||||
? (discProperty.const as string)
|
||||
: undefined; // NOTE: can discriminators only be strings?
|
||||
|
||||
return {
|
||||
value: variantDiscValue,
|
||||
|
||||
@@ -79,17 +79,43 @@ export default function useAgentGraph(
|
||||
|
||||
// Load available blocks & flows on mount; blocks are reloaded on LLM registry refresh notifications
|
||||
useEffect(() => {
|
||||
api
|
||||
.getBlocks()
|
||||
.then((blocks) => {
|
||||
setAllBlocks(blocks);
|
||||
})
|
||||
.catch();
|
||||
const loadBlocks = () => {
|
||||
api
|
||||
.getBlocks()
|
||||
.then((blocks) => {
|
||||
setAllBlocks(blocks);
|
||||
})
|
||||
.catch();
|
||||
};
|
||||
|
||||
api
|
||||
.listGraphs()
|
||||
.then((flows) => setAvailableFlows(flows))
|
||||
.catch();
|
||||
const loadFlows = () => {
|
||||
api
|
||||
.listGraphs()
|
||||
.then((flows) => setAvailableFlows(flows))
|
||||
.catch();
|
||||
};
|
||||
|
||||
// Initial load
|
||||
loadBlocks();
|
||||
loadFlows();
|
||||
|
||||
// Listen for LLM registry refresh notifications to reload blocks
|
||||
const deregisterRegistryRefresh = api.onWebSocketMessage(
|
||||
"notification",
|
||||
(notification) => {
|
||||
if (
|
||||
notification?.type === "LLM_REGISTRY_REFRESH" ||
|
||||
notification?.event === "registry_updated"
|
||||
) {
|
||||
console.log("Received LLM registry refresh notification, reloading blocks...");
|
||||
loadBlocks();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
return () => {
|
||||
deregisterRegistryRefresh();
|
||||
};
|
||||
}, [api]);
|
||||
|
||||
// Subscribe to execution events
|
||||
|
||||
@@ -43,6 +43,7 @@ import type {
|
||||
LibraryAgentPresetUpdatable,
|
||||
LibraryAgentResponse,
|
||||
LibraryAgentSortEnum,
|
||||
LlmModel,
|
||||
MyAgentsResponse,
|
||||
NodeExecutionResult,
|
||||
NotificationPreference,
|
||||
@@ -52,6 +53,13 @@ import type {
|
||||
ProfileDetails,
|
||||
RefundRequest,
|
||||
ReviewSubmissionRequest,
|
||||
CreateLlmModelRequest,
|
||||
UpdateLlmModelRequest,
|
||||
ToggleLlmModelRequest,
|
||||
UpsertLlmProviderRequest,
|
||||
LlmModelsResponse,
|
||||
LlmProvider,
|
||||
LlmProvidersResponse,
|
||||
Schedule,
|
||||
ScheduleCreatable,
|
||||
ScheduleID,
|
||||
@@ -395,6 +403,74 @@ export default class BackendAPI {
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
/////////////// LLM MODELS /////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
/** Lists LLM models via GET `/llm/models`. */
listLlmModels(): Promise<LlmModelsResponse> {
  const path = "/llm/models";
  return this._get(path);
}
|
||||
|
||||
/**
 * Lists LLM providers via GET `/llm/providers`.
 *
 * @param includeModels When truthy, sends `include_models=true`; otherwise no
 *   query is sent at all.
 *   NOTE(review): a falsy value therefore relies on the endpoint's own
 *   default for `include_models` — confirm that default is "off".
 */
listLlmProviders(includeModels = true): Promise<LlmProvidersResponse> {
  if (!includeModels) {
    return this._get("/llm/providers", undefined);
  }
  return this._get("/llm/providers", { include_models: true });
}
|
||||
|
||||
/**
 * Lists LLM providers via the admin route GET `/llm/admin/llm/providers`.
 *
 * @param includeModels When truthy, sends `include_models=true`; otherwise no
 *   query is sent.
 */
listAdminLlmProviders(
  includeModels = true,
): Promise<LlmProvidersResponse> {
  if (!includeModels) {
    return this._get("/llm/admin/llm/providers", undefined);
  }
  return this._get("/llm/admin/llm/providers", { include_models: true });
}
|
||||
|
||||
/** Creates a provider by POSTing `payload` to `/llm/admin/llm/providers`. */
createAdminLlmProvider(
  payload: UpsertLlmProviderRequest,
): Promise<LlmProvider> {
  const path = "/llm/admin/llm/providers";
  return this._request("POST", path, payload);
}
|
||||
|
||||
/**
 * Updates a provider by PATCHing `payload` to
 * `/llm/admin/llm/providers/{providerId}`.
 *
 * @param providerId Interpolated into the URL path as-is.
 * @param payload Request body.
 */
updateAdminLlmProvider(
  providerId: string,
  payload: UpsertLlmProviderRequest,
): Promise<LlmProvider> {
  const path = `/llm/admin/llm/providers/${providerId}`;
  return this._request("PATCH", path, payload);
}
|
||||
|
||||
/**
 * Lists LLM models via the admin route GET `/llm/admin/llm/models`.
 *
 * @param providerId When non-empty, sent as the `provider_id` query parameter;
 *   otherwise no query is sent.
 */
listAdminLlmModels(providerId?: string): Promise<LlmModelsResponse> {
  if (!providerId) {
    return this._get("/llm/admin/llm/models", undefined);
  }
  return this._get("/llm/admin/llm/models", { provider_id: providerId });
}
|
||||
|
||||
/** Creates a model by POSTing `payload` to `/llm/admin/llm/models`. */
createAdminLlmModel(payload: CreateLlmModelRequest): Promise<LlmModel> {
  const path = "/llm/admin/llm/models";
  return this._request("POST", path, payload);
}
|
||||
|
||||
/**
 * Updates a model by PATCHing `payload` to `/llm/admin/llm/models/{modelId}`.
 *
 * @param modelId Interpolated into the URL path as-is.
 * @param payload Request body.
 */
updateAdminLlmModel(
  modelId: string,
  payload: UpdateLlmModelRequest,
): Promise<LlmModel> {
  const path = `/llm/admin/llm/models/${modelId}`;
  return this._request("PATCH", path, payload);
}
|
||||
|
||||
/**
 * Sends the `is_enabled` payload via PATCH to
 * `/llm/admin/llm/models/{modelId}/toggle`.
 *
 * @param modelId Interpolated into the URL path as-is.
 * @param payload Request body carrying the `is_enabled` flag.
 */
toggleAdminLlmModel(
  modelId: string,
  payload: ToggleLlmModelRequest,
): Promise<LlmModel> {
  const path = `/llm/admin/llm/models/${modelId}/toggle`;
  return this._request("PATCH", path, payload);
}
|
||||
|
||||
// API Key related requests
|
||||
async createAPIKey(
|
||||
name: string,
|
||||
|
||||
@@ -234,10 +234,8 @@ export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
|
||||
|
||||
// At the time of writing, combined schemas only occur on the first nested level in a
|
||||
// block schema. It is typed this way to make the use of these objects less tedious.
|
||||
type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
|
||||
type: never;
|
||||
const: never;
|
||||
} & (
|
||||
type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta &
|
||||
(
|
||||
| {
|
||||
allOf: [BlockIOSimpleTypeSubSchema];
|
||||
secret?: boolean;
|
||||
@@ -252,6 +250,107 @@ type BlockIOCombinedTypeSubSchema = BlockIOSubSchemaMeta & {
|
||||
| BlockIODiscriminatedOneOfSubSchema
|
||||
);
|
||||
|
||||
////////////////////////////////////////
|
||||
///////////// LLM REGISTRY /////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
/** Unit a model cost entry is expressed in. */
export type LlmCostUnit = "RUN" | "TOKENS";

/** Cost entry as submitted when creating or updating a model. */
export type LlmModelCostInput = {
  unit?: LlmCostUnit;
  credit_cost: number;
  credential_provider: string;
  credential_id?: string | null;
  credential_type?: string | null;
  currency?: string | null;
  metadata?: Record<string, any>;
};

/** Cost entry as returned by the API: the input shape plus its `id`. */
export type LlmModelCost = LlmModelCostInput & {
  id: string;
};

/** A registered LLM model together with its cost entries. */
export type LlmModel = {
  id: string;
  slug: string;
  display_name: string;
  description?: string | null;
  provider_id: string;
  context_window: number;
  max_output_tokens?: number | null;
  is_enabled: boolean;
  capabilities: Record<string, any>;
  metadata: Record<string, any>;
  costs: LlmModelCost[];
};

/** An LLM provider with default-credential fields and capability flags. */
export type LlmProvider = {
  id: string;
  name: string;
  display_name: string;
  description?: string | null;
  default_credential_provider?: string | null;
  default_credential_id?: string | null;
  default_credential_type?: string | null;
  supports_tools: boolean;
  supports_json_output: boolean;
  supports_reasoning: boolean;
  supports_parallel_tool: boolean;
  metadata: Record<string, any>;
  // Optional in the type; may be absent from a response.
  models?: LlmModel[];
};

/** Response envelope for provider listings. */
export type LlmProvidersResponse = {
  providers: LlmProvider[];
};

/** Response envelope for model listings. */
export type LlmModelsResponse = {
  models: LlmModel[];
};

/** Request body used for both creating and updating a provider. */
export type UpsertLlmProviderRequest = {
  name: string;
  display_name: string;
  description?: string | null;
  default_credential_provider?: string | null;
  default_credential_id?: string | null;
  default_credential_type?: string | null;
  supports_tools?: boolean;
  supports_json_output?: boolean;
  supports_reasoning?: boolean;
  supports_parallel_tool?: boolean;
  metadata?: Record<string, any>;
};

/** Request body for creating a model. */
export type CreateLlmModelRequest = {
  slug: string;
  display_name: string;
  description?: string | null;
  provider_id: string;
  context_window: number;
  max_output_tokens?: number | null;
  is_enabled?: boolean;
  capabilities?: Record<string, any>;
  metadata?: Record<string, any>;
  costs: LlmModelCostInput[];
};

/**
 * Request body for updating a model: every create field except `slug`,
 * all optional.
 */
export type UpdateLlmModelRequest = Partial<
  Omit<CreateLlmModelRequest, "slug">
>;

/** Request body for the model toggle endpoint. */
export type ToggleLlmModelRequest = {
  is_enabled: boolean;
};
|
||||
|
||||
export type BlockIOOneOfSubSchema = {
|
||||
oneOf: BlockIOSimpleTypeSubSchema[];
|
||||
default?: string | number | boolean | null;
|
||||
@@ -1135,16 +1234,21 @@ function _handleStringSchema(strSchema: BlockIOStringSubSchema): DataType {
|
||||
}
|
||||
|
||||
function _handleSingleTypeSchema(subSchema: BlockIOSubSchema): DataType {
|
||||
if (subSchema.type === "string") {
|
||||
const schemaType =
|
||||
"type" in subSchema && typeof subSchema.type === "string"
|
||||
? subSchema.type
|
||||
: undefined;
|
||||
|
||||
if (schemaType === "string") {
|
||||
return _handleStringSchema(subSchema as BlockIOStringSubSchema);
|
||||
}
|
||||
if (subSchema.type === "boolean") {
|
||||
if (schemaType === "boolean") {
|
||||
return DataType.BOOLEAN;
|
||||
}
|
||||
if (subSchema.type === "number" || subSchema.type === "integer") {
|
||||
if (schemaType === "number" || schemaType === "integer") {
|
||||
return DataType.NUMBER;
|
||||
}
|
||||
if (subSchema.type === "array") {
|
||||
if (schemaType === "array") {
|
||||
// Check for table format first
|
||||
if ("format" in subSchema && subSchema.format === "table") {
|
||||
return DataType.TABLE;
|
||||
@@ -1155,7 +1259,7 @@ function _handleSingleTypeSchema(subSchema: BlockIOSubSchema): DataType {
|
||||
// }
|
||||
return DataType.ARRAY;
|
||||
}
|
||||
if (subSchema.type === "object") {
|
||||
if (schemaType === "object") {
|
||||
if (
|
||||
("additionalProperties" in subSchema && subSchema.additionalProperties) ||
|
||||
!("properties" in subSchema)
|
||||
@@ -1164,7 +1268,7 @@ function _handleSingleTypeSchema(subSchema: BlockIOSubSchema): DataType {
|
||||
}
|
||||
if (
|
||||
Object.values(subSchema.properties).every(
|
||||
(prop) => prop.type === "boolean",
|
||||
(prop) => "type" in prop && prop.type === "boolean",
|
||||
)
|
||||
) {
|
||||
return DataType.MULTI_SELECT; // if all props are boolean => multi-select
|
||||
|
||||
@@ -234,17 +234,20 @@ export function fillObjectDefaultsFromSchema(
|
||||
// Apply simple default values
|
||||
obj[key] ??= propertySchema.default;
|
||||
} else if (
|
||||
"type" in propertySchema &&
|
||||
propertySchema.type === "object" &&
|
||||
"properties" in propertySchema
|
||||
) {
|
||||
// Recursively fill defaults for nested objects
|
||||
obj[key] = fillObjectDefaultsFromSchema(obj[key] ?? {}, propertySchema);
|
||||
} else if (propertySchema.type === "array") {
|
||||
} else if ("type" in propertySchema && propertySchema.type === "array") {
|
||||
obj[key] ??= [];
|
||||
// If the array items are objects, fill their defaults as well
|
||||
if (
|
||||
Array.isArray(obj[key]) &&
|
||||
propertySchema.items?.type === "object" &&
|
||||
propertySchema.items &&
|
||||
"type" in propertySchema.items &&
|
||||
propertySchema.items.type === "object" &&
|
||||
"properties" in propertySchema.items
|
||||
) {
|
||||
for (const item of obj[key]) {
|
||||
|
||||
Reference in New Issue
Block a user