Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
c814e9234f chore(deps): bump docker/login-action from 3 to 4
Bumps [docker/login-action](https://github.com/docker/login-action) from 3 to 4.
- [Release notes](https://github.com/docker/login-action/releases)
- [Commits](https://github.com/docker/login-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: docker/login-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-09 19:13:59 +00:00
35 changed files with 2110 additions and 10243 deletions

View File

@@ -107,7 +107,7 @@ jobs:
- if: github.event_name == 'push'
name: Log in to Docker hub
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}

View File

@@ -23,7 +23,7 @@ jobs:
uses: actions/checkout@v4
- name: Log in to Docker hub
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url_host(auth_server_url)
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url_host(registration_endpoint)
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url_host is enforced on all endpoints."""
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -37,10 +37,8 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
@@ -119,30 +117,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -369,11 +348,6 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.server.v2.llm.router,
tags=["v2", "llm"],
prefix="/api",
)
app.mount("/external-api", external_api)

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url_host(input_data.url)
parsed_url, _, _ = await validate_url(input_data.url, [])
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -34,11 +34,8 @@ from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -808,11 +805,6 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url_host
from backend.util.request import validate_url
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url_host(url)
await validate_url(url, trusted_origins=[])
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url_host(url)
await validate_url(url, trusted_origins=[])
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,18 +68,17 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url_host (backend.util.request)
# SSRF protection via shared validate_url (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
"""Verify that browser_navigate uses validate_url for SSRF protection.
We mock validate_url_host itself (not the low-level socket) so these tests
We mock validate_url itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -90,7 +89,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -125,8 +124,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -144,7 +143,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com")
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
# ---------------------------------------------------------------------------

View File

@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url_host
from backend.util.request import HTTPClientError, validate_url
from .base import BaseTool
from .models import (
@@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool):
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url_host(server_url)
await validate_url(server_url, trusted_origins=[])
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:

View File

@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
"backend.copilot.tools.run_mcp_tool.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host",
"backend.copilot.tools.run_mcp_tool.validate_url",
new_callable=AsyncMock,
):
with patch(
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
mock_tools = _make_tool_list("fetch", "search")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
mock_tools = _make_tool_list("push_notification")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
text_result = "# Example Domain\nThis domain is for examples."
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
mock_creds.access_token = SecretStr("stale-token")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
url_with_slash = "https://mcp.example.com/mcp/"
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",

View File

@@ -1,31 +0,0 @@
"""LLM Registry - Dynamic model management system."""
from .model import ModelMetadata
from .registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_all_models,
get_default_model_slug,
get_enabled_models,
get_model,
get_schema_options,
refresh_llm_registry,
)
__all__ = [
# Models
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Functions
"refresh_llm_registry",
"get_model",
"get_all_models",
"get_enabled_models",
"get_schema_options",
"get_default_model_slug",
"get_all_model_slugs_for_validation",
]

View File

@@ -1,9 +0,0 @@
"""Type definitions for LLM model metadata.
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
__all__ = ["ModelMetadata"]

View File

@@ -1,240 +0,0 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
# In-memory cache (will be replaced with Redis in PR #6)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
Fetches all models with their costs, providers, and creators,
then updates the in-memory cache.
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.info(f"Fetched {len(records)} LLM models from database")
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
# Create model instance
model = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
new_models[record.slug] = model
# Atomic swap
global _dynamic_models, _schema_options
_dynamic_models = new_models
_schema_options = _build_schema_options()
logger.info(
f"LLM registry refreshed: {len(_dynamic_models)} models, "
f"{len(_schema_options)} schema options"
)
except Exception as e:
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
raise
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
return [
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
for model in sorted(
_dynamic_models.values(), key=lambda m: m.display_name.lower()
)
if model.is_enabled
]
def get_model(slug: str) -> RegistryModel | None:
"""Get a model by slug from the registry."""
return _dynamic_models.get(slug)
def get_all_models() -> list[RegistryModel]:
"""Get all models from the registry (including disabled)."""
return list(_dynamic_models.values())
def get_enabled_models() -> list[RegistryModel]:
"""Get only enabled models from the registry."""
return [model for model in _dynamic_models.values() if model.is_enabled]
def get_schema_options() -> list[dict[str, str]]:
"""Get schema options for model selection dropdown (enabled models only)."""
return _schema_options
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
return next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]:
"""
Get all model slugs for validation (enables migrate_llm_models to work).
Returns slugs for enabled models only.
"""
return [model.slug for model in _dynamic_models.values() if model.is_enabled]

View File

@@ -1,5 +0,0 @@
"""LLM registry public API."""
from .routes import router
__all__ = ["router"]

View File

@@ -1,67 +0,0 @@
"""Pydantic models for LLM registry public API."""
from __future__ import annotations
from typing import Any
import pydantic
class LlmModelCost(pydantic.BaseModel):
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int = pydantic.Field(ge=0)
credential_provider: str
credential_id: str | None = None
credential_type: str | None = None
currency: str | None = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model."""
id: str
name: str
display_name: str
description: str | None = None
website_url: str | None = None
logo_url: str | None = None
class LlmModel(pydantic.BaseModel):
"""Public-facing LLM model information."""
slug: str
display_name: str
description: str | None = None
provider_name: str
creator: LlmModelCreator | None = None
context_window: int
max_output_tokens: int | None = None
price_tier: int # 1=cheapest, 2=medium, 3=expensive
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
"""Provider with its enabled models."""
name: str
display_name: str
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmModelsResponse(pydantic.BaseModel):
"""Response for GET /llm/models."""
models: list[LlmModel]
total: int
class LlmProvidersResponse(pydantic.BaseModel):
"""Response for GET /llm/providers."""
providers: list[LlmProvider]

View File

@@ -1,141 +0,0 @@
"""Public read-only API for LLM registry."""
import autogpt_libs.auth
import fastapi
from backend.data.llm_registry import (
RegistryModelCreator,
get_all_models,
get_enabled_models,
)
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
def _map_creator(
creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
"""Convert registry creator to API model."""
if not creator:
return None
return llm_model.LlmModelCreator(
id=creator.id,
name=creator.name,
display_name=creator.display_name,
description=creator.description,
website_url=creator.website_url,
logo_url=creator.logo_url,
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
enabled_only: bool = fastapi.Query(
default=True, description="Only return enabled models"
),
):
"""
List all LLM models available to users.
Returns models from the in-memory registry cache.
Use enabled_only=true to filter to only enabled models (default).
"""
# Get models from in-memory registry
registry_models = get_enabled_models() if enabled_only else get_all_models()
# Map to API response models
models = [
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in registry_models
]
return llm_model.LlmModelsResponse(models=models, total=len(models))
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""
List all LLM providers with their enabled models.
Groups enabled models by provider from the in-memory registry.
"""
# Get all enabled models and group by provider
registry_models = get_enabled_models()
# Group models by provider
provider_map: dict[str, list] = {}
for model in registry_models:
provider_key = model.metadata.provider
if provider_key not in provider_map:
provider_map[provider_key] = []
provider_map[provider_key].append(model)
# Build provider responses
providers = []
for provider_key, models in sorted(provider_map.items()):
# Use the first model's provider display name
display_name = models[0].provider_display_name if models else provider_key
providers.append(
llm_model.LlmProvider(
name=provider_key,
display_name=display_name,
models=[
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in sorted(models, key=lambda m: m.display_name)
],
)
)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -144,106 +144,76 @@ async def _resolve_host(hostname: str) -> list[str]:
return ip_addresses
async def validate_url_host(
url: str, trusted_hostnames: Optional[list[str]] = None
async def validate_url(
url: str, trusted_origins: list[str]
) -> tuple[URL, bool, list[str]]:
"""
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
point to a private, link-local, or otherwise blocked IP address — unless
Validates the URL to prevent SSRF attacks by ensuring it does not point
to a private, link-local, or otherwise blocked IP address — unless
the hostname is explicitly trusted.
Hosts in `trusted_hostnames` are permitted without checks.
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
Params:
url: A hostname, netloc, or URL to validate.
If no scheme is included, `http://` is assumed.
trusted_hostnames: A list of hostnames that don't require validation.
Raises:
ValueError:
- if the URL has a disallowed URL scheme
- if the URL/host string can't be parsed
- if the hostname contains invalid or unsupported (non-ASCII) characters
- if the host resolves to a blocked IP
Returns:
1. The validated, canonicalized, parsed host/URL,
with hostname ASCII-safe encoding
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
3. List of resolved IP addresses for the host; empty if the host is trusted.
str: The validated, canonicalized, parsed URL
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
"""
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
f"{', '.join(ALLOWED_SCHEMES)}"
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
)
# Validate and IDNA encode hostname
if not parsed.hostname:
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
raise ValueError("Invalid URL: No hostname found.")
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
raise ValueError("Invalid hostname with unsupported characters.")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
raise ValueError("Hostname contains invalid characters.")
# Re-create parsed URL object with IDNA-encoded hostname
parsed = URL(
parsed.scheme,
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
# Check if hostname is trusted
is_trusted = ascii_hostname in trusted_origins
# If not trusted, validate IP addresses
ip_addresses: list[str] = []
if not is_trusted:
# Resolve all IP addresses for the hostname
ip_addresses = await _resolve_host(ascii_hostname)
# Block any IP address that belongs to a blocked range
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {ascii_hostname} is not allowed."
)
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
netloc = ascii_hostname
if parsed.port:
netloc = f"{ascii_hostname}:{parsed.port}"
return (
URL(
parsed.scheme,
netloc,
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
),
is_trusted,
ip_addresses,
)
is_trusted = trusted_hostnames and any(
matches_allowed_host(parsed, allowed)
for allowed in (
# Normalize + parse allowlist entries the same way for consistent matching
parse_url(w)
for w in trusted_hostnames
)
)
if is_trusted:
return parsed, True, []
# If not allowlisted, go ahead with host resolution and IP target check
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
def matches_allowed_host(url: URL, allowed: URL) -> bool:
if url.hostname != allowed.hostname:
return False
# Allow any port if not explicitly specified in the allowlist
if allowed.port is None:
return True
return url.port == allowed.port
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
"""
Resolves hostname to IPs and raises ValueError if any resolve to
a blocked network. Returns the list of resolved IP addresses.
"""
ip_addresses = await _resolve_host(hostname)
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {hostname} is not allowed."
)
return ip_addresses
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
@@ -382,7 +352,7 @@ class Requests:
):
self.trusted_origins = []
for url in trusted_origins or []:
hostname = parse_url(url).netloc # {host}[:{port}]
hostname = urlparse(url).hostname
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
self.trusted_origins.append(hostname)
@@ -480,7 +450,7 @@ class Requests:
data = form
# Validate URL and get trust status
parsed_url, is_trusted, ip_addresses = await validate_url_host(
parsed_url, is_trusted, ip_addresses = await validate_url(
url, self.trusted_origins
)
@@ -533,6 +503,7 @@ class Requests:
json=json,
**kwargs,
) as response:
if self.raise_for_status:
try:
response.raise_for_status()

View File

@@ -1,7 +1,7 @@
import pytest
from aiohttp import web
from backend.util.request import pin_url, validate_url_host
from backend.util.request import pin_url, validate_url
@pytest.mark.parametrize(
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
):
if should_raise:
with pytest.raises(ValueError):
await validate_url_host(raw_url, trusted_origins)
await validate_url(raw_url, trusted_origins)
else:
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
assert validated_url.geturl() == expected_value
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
if expect_error:
# If any IP is blocked, we expect a ValueError
with pytest.raises(ValueError):
url, _, ip_addresses = await validate_url_host(hostname)
url, _, ip_addresses = await validate_url(hostname, [])
pin_url(url, ip_addresses)
else:
url, _, ip_addresses = await validate_url_host(hostname)
url, _, ip_addresses = await validate_url(hostname, [])
pinned_url = pin_url(url, ip_addresses).geturl()
# The pinned_url should contain the first valid IP
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")

View File

@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",

View File

@@ -1,148 +0,0 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"creatorId" TEXT,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"priceTier" INTEGER NOT NULL DEFAULT 1,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"isRecommended" BOOLEAN NOT NULL DEFAULT false,
"supportsTools" BOOLEAN NOT NULL DEFAULT false,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
"supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
"capabilities" JSONB NOT NULL DEFAULT '{}',
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;
-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");
-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddCheckConstraints (enforce data integrity)
ALTER TABLE "LlmModel"
ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);
ALTER TABLE "LlmModelCost"
ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);
ALTER TABLE "LlmModelMigration"
ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);

View File

@@ -1304,164 +1304,3 @@ model OAuthRefreshToken {
@@index([userId, applicationId])
@@index([expiresAt]) // For cleanup
}
// ============================================================================
// LLM Registry Models
// ============================================================================
enum LlmCostUnit {
RUN
TOKENS
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
// Model-specific capabilities
// These vary per model even within the same provider (e.g., Hugging Face)
// Default to false for safety - partially-seeded rows should not be assumed capable
supportsTools Boolean @default(false)
supportsJsonOutput Boolean @default(false)
supportsReasoning Boolean @default(false)
supportsParallelToolCalls Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
@@index([providerId, isEnabled])
@@index([creatorId])
// Note: slug already has @unique which creates an implicit index
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int // DB constraint: >= 0
// Provider identifier (e.g., "openai", "anthropic", "openrouter")
// Used to determine which credential system provides the API key.
// Allows different pricing for:
// - Default provider costs (WHERE credentialId IS NULL)
// - User's own API key costs (WHERE credentialId IS NOT NULL)
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
// Note: Unique constraints are implemented as partial indexes in migration SQL:
// - One for default costs (WHERE credentialId IS NULL)
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
// This allows both provider-level defaults and credential-specific overrides
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// FK constraints ensure slugs reference valid models
SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated (DB constraint: >= 0)
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
// LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
// For token-priced models, this may be ambiguous. Consider migrating to a relation
// with LlmModelCost or a dedicated override model in a follow-up PR.
customCreditCost Int? // DB constraint: >= 0 when not null
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
// Note: Partial unique index in migration SQL prevents multiple active migrations per source:
// UNIQUE (sourceModelSlug) WHERE isReverted = false
@@index([targetModelSlug])
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}

View File

@@ -0,0 +1,36 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api/client";
import { OttoQuery, OttoResponse } from "@/lib/autogpt-server-api/types";
const api = new BackendAPI();
export async function askOtto(
query: string,
conversationHistory: { query: string; response: string }[],
includeGraphData: boolean,
graphId?: string,
): Promise<OttoResponse> {
const messageId = `${Date.now()}-web`;
const ottoQuery: OttoQuery = {
query,
conversation_history: conversationHistory,
message_id: messageId,
include_graph_data: includeGraphData,
graph_id: graphId,
};
try {
const response = await api.askOtto(ottoQuery);
return response;
} catch (error) {
console.error("Error in askOtto server action:", error);
return {
answer: error instanceof Error ? error.message : "Unknown error occurred",
documents: [],
success: false,
error: true,
};
}
}

View File

@@ -23,12 +23,6 @@ import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
import { useCustomNode } from "./useCustomNode";
function hasAdvancedFields(schema: RJSFSchema): boolean {
const properties = schema?.properties;
if (!properties) return false;
return Object.values(properties).some((prop: any) => prop.advanced === true);
}
export type CustomNodeData = {
hardcodedValues: {
[key: string]: any;
@@ -114,11 +108,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
)}
showHandles={showHandles}
/>
<NodeAdvancedToggle
nodeId={nodeId}
isLastSection={data.uiType === BlockUIType.OUTPUT}
hasAdvancedFields={hasAdvancedFields(inputSchema)}
/>
<NodeAdvancedToggle nodeId={nodeId} />
{data.uiType != BlockUIType.OUTPUT && (
<OutputHandler
uiType={data.uiType}

View File

@@ -2,33 +2,18 @@ import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
type Props = {
nodeId: string;
isLastSection?: boolean;
hasAdvancedFields?: boolean;
};
export function NodeAdvancedToggle({
nodeId,
isLastSection,
hasAdvancedFields = true,
}: Props) {
export function NodeAdvancedToggle({ nodeId }: Props) {
const showAdvanced = useNodeStore(
(state) => state.nodeAdvancedStates[nodeId] || false,
);
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
if (!hasAdvancedFields) return null;
return (
<div
className={cn(
"flex items-center justify-start gap-2 bg-white px-5 pb-3.5",
isLastSection && "rounded-b-xlarge",
)}
>
<div className="flex items-center justify-start gap-2 bg-white px-5 pb-3.5">
<Button
variant="ghost"
className="h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"

View File

@@ -1,6 +1,6 @@
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon, CaretRightIcon, InfoIcon } from "@phosphor-icons/react";
import { CaretDownIcon, InfoIcon } from "@phosphor-icons/react";
import { RJSFSchema } from "@rjsf/utils";
import { useState } from "react";
@@ -30,41 +30,13 @@ export const OutputHandler = ({
const properties = outputSchema?.properties || {};
const [isOutputVisible, setIsOutputVisible] = useState(true);
const brokenOutputs = useBrokenOutputs(nodeId);
const [expandedObjects, setExpandedObjects] = useState<
Record<string, boolean>
>({});
const showHandles = uiType !== BlockUIType.OUTPUT;
function toggleObjectExpanded(key: string) {
setExpandedObjects((prev) => ({ ...prev, [key]: !prev[key] }));
}
function hasConnectedOrBrokenDescendant(
schema: RJSFSchema,
keyPrefix: string,
): boolean {
if (!schema) return false;
return Object.entries(schema).some(
([key, fieldSchema]: [string, RJSFSchema]) => {
const fullKey = keyPrefix ? `${keyPrefix}_#_${key}` : key;
if (isOutputConnected(nodeId, fullKey) || brokenOutputs.has(fullKey))
return true;
if (fieldSchema?.properties)
return hasConnectedOrBrokenDescendant(
fieldSchema.properties,
fullKey,
);
return false;
},
);
}
const renderOutputHandles = (
schema: RJSFSchema,
keyPrefix: string = "",
titlePrefix: string = "",
connectedOnly: boolean = false,
): React.ReactNode[] => {
return Object.entries(schema).map(
([key, fieldSchema]: [string, RJSFSchema]) => {
@@ -72,23 +44,10 @@ export const OutputHandler = ({
const fieldTitle = titlePrefix + (fieldSchema?.title || key);
const isConnected = isOutputConnected(nodeId, fullKey);
const isBroken = brokenOutputs.has(fullKey);
const hasNestedProperties = !!fieldSchema?.properties;
const selfIsRelevant = isConnected || isBroken;
const descendantIsRelevant =
hasNestedProperties &&
hasConnectedOrBrokenDescendant(fieldSchema.properties!, fullKey);
const shouldShow = connectedOnly
? selfIsRelevant || descendantIsRelevant
: isOutputVisible || selfIsRelevant || descendantIsRelevant;
const shouldShow = isConnected || isOutputVisible;
const { displayType, colorClass, hexColor } =
getTypeDisplayInfo(fieldSchema);
const isExpanded = expandedObjects[fullKey] ?? false;
// User expanded → show all children; auto-expanded → filter to connected only
const shouldRenderChildren = isExpanded || descendantIsRelevant;
const isBroken = brokenOutputs.has(fullKey);
return shouldShow ? (
<div
@@ -97,19 +56,6 @@ export const OutputHandler = ({
data-tutorial-id={`output-handler-${nodeId}-${fieldTitle}`}
>
<div className="relative flex items-center gap-2">
{hasNestedProperties && (
<button
onClick={() => toggleObjectExpanded(fullKey)}
className="flex items-center text-slate-500 hover:text-slate-700"
aria-label={isExpanded ? "Collapse" : "Expand"}
>
{isExpanded ? (
<CaretDownIcon size={12} weight="bold" />
) : (
<CaretRightIcon size={12} weight="bold" />
)}
</button>
)}
{fieldSchema?.description && (
<TooltipProvider>
<Tooltip>
@@ -156,14 +102,12 @@ export const OutputHandler = ({
)}
</div>
{/* Nested properties */}
{hasNestedProperties &&
shouldRenderChildren &&
{/* Recursively render nested properties */}
{fieldSchema?.properties &&
renderOutputHandles(
fieldSchema.properties!,
fieldSchema.properties,
fullKey,
"",
!isExpanded,
`${fieldTitle}.`,
)}
</div>
) : null;
@@ -192,7 +136,7 @@ export const OutputHandler = ({
</Button>
<div className="flex flex-col items-end gap-2">
{renderOutputHandles(properties, "", "", !isOutputVisible)}
{renderOutputHandles(properties)}
</div>
</div>
);

View File

@@ -25,60 +25,34 @@ type HistoryStore = {
const MAX_HISTORY = 50;
// Microtask batching state — kept outside the store to avoid triggering
// re-renders. When multiple pushState calls happen in the same synchronous
// execution (e.g. node deletion cascading to edge cleanup), only the first
// (pre-change) state is kept and committed as a single history entry.
let pendingState: HistoryState | null = null;
let batchScheduled = false;
export const useHistoryStore = create<HistoryStore>((set, get) => ({
past: [{ nodes: [], edges: [] }],
future: [],
pushState: (state: HistoryState) => {
// Keep only the first state within a microtask batch — it represents
// the true pre-change snapshot before any cascading mutations.
if (!pendingState) {
pendingState = state;
const { past } = get();
const lastState = past[past.length - 1];
if (lastState && isEqual(lastState, state)) {
return;
}
if (!batchScheduled) {
batchScheduled = true;
queueMicrotask(() => {
const stateToCommit = pendingState;
pendingState = null;
batchScheduled = false;
const actualCurrentState = {
nodes: useNodeStore.getState().nodes,
edges: useEdgeStore.getState().edges,
};
if (!stateToCommit) return;
const { past } = get();
const lastState = past[past.length - 1];
if (lastState && isEqual(lastState, stateToCommit)) {
return;
}
const actualCurrentState = {
nodes: useNodeStore.getState().nodes,
edges: useEdgeStore.getState().edges,
};
if (isEqual(stateToCommit, actualCurrentState)) {
return;
}
set((prev) => ({
past: [...prev.past.slice(-MAX_HISTORY + 1), stateToCommit],
future: [],
}));
});
if (isEqual(state, actualCurrentState)) {
return;
}
set((prev) => ({
past: [...prev.past.slice(-MAX_HISTORY + 1), state],
future: [],
}));
},
initializeHistory: () => {
pendingState = null;
const currentNodes = useNodeStore.getState().nodes;
const currentEdges = useEdgeStore.getState().edges;
@@ -148,8 +122,5 @@ export const useHistoryStore = create<HistoryStore>((set, get) => ({
},
canRedo: () => get().future.length > 0,
clear: () => {
pendingState = null;
set({ past: [{ nodes: [], edges: [] }], future: [] });
},
clear: () => set({ past: [{ nodes: [], edges: [] }], future: [] }),
}));

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,10 @@ export function FormRenderer({
}, [preprocessedSchema, uiSchema]);
return (
<div className={cn("mt-4", className)} data-tutorial-id="input-handles">
<div
className={cn("mb-6 mt-4", className)}
data-tutorial-id="input-handles"
>
<Form
formContext={formContext}
idPrefix="agpt"

View File

@@ -63,6 +63,7 @@ export const useAnyOfField = (props: FieldProps) => {
);
const handlePrefix = cleanUpHandleId(field_id);
console.log("handlePrefix", handlePrefix);
useEdgeStore
.getState()
.removeEdgesByHandlePrefix(registry.formContext.nodeId, handlePrefix);

View File

@@ -4,7 +4,6 @@ import {
TemplatesType,
} from "@rjsf/utils";
import { AnyOfField } from "./anyof/AnyOfField";
import { OneOfField } from "./oneof/OneOfField";
import {
ArrayFieldItemTemplate,
ArrayFieldTemplate,
@@ -33,7 +32,6 @@ const NoButton = () => null;
export function generateBaseFields(): RegistryFieldsType {
return {
AnyOfField,
OneOfField,
ArraySchemaField,
};
}

View File

@@ -1,243 +0,0 @@
import {
descriptionId,
FieldProps,
getTemplate,
getUiOptions,
getWidget,
} from "@rjsf/utils";
import { useEffect, useRef, useState } from "react";
import { AnyOfField } from "../anyof/AnyOfField";
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { ANY_OF_FLAG } from "../../constants";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
function getDiscriminatorPropName(schema: any): string | undefined {
if (!schema?.discriminator) return undefined;
if (typeof schema.discriminator === "string") return schema.discriminator;
return schema.discriminator.propertyName;
}
export function OneOfField(props: FieldProps) {
const { schema } = props;
const discriminatorProp = getDiscriminatorPropName(schema);
if (!discriminatorProp) {
return <AnyOfField {...props} />;
}
return (
<DiscriminatedUnionField {...props} discriminatorProp={discriminatorProp} />
);
}
interface DiscriminatedUnionFieldProps extends FieldProps {
discriminatorProp: string;
}
function DiscriminatedUnionField({
discriminatorProp,
...props
}: DiscriminatedUnionFieldProps) {
const { schema, registry, formData, onChange, name } = props;
const { fields, schemaUtils, formContext } = registry;
const { SchemaField } = fields;
const { nodeId } = formContext;
const field_id = props.fieldPathId.$id;
// Resolve variant schemas from $refs
const variants = useRef(
(schema.oneOf || []).map((opt: any) =>
schemaUtils.retrieveSchema(opt, formData),
),
);
// Build dropdown options from variant titles and discriminator const values
const enumOptions = variants.current.map((variant: any, index: number) => {
const discValue = (variant.properties?.[discriminatorProp] as any)?.const;
return {
value: index,
label: variant.title || discValue || `Option ${index + 1}`,
discriminatorValue: discValue,
};
});
// Determine initial selected index from formData
function getInitialIndex() {
const currentDisc = formData?.[discriminatorProp];
if (currentDisc) {
const idx = enumOptions.findIndex(
(o) => o.discriminatorValue === currentDisc,
);
if (idx >= 0) return idx;
}
return 0;
}
const [selectedIndex, setSelectedIndex] = useState(getInitialIndex);
// Generate handleId for sub-fields (same convention as AnyOfField)
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
const handleId = getHandleId({
uiOptions,
id: field_id + ANY_OF_FLAG,
schema,
});
const childUiSchema = updateUiOption(props.uiSchema, {
handleId,
label: false,
fromAnyOf: true,
});
// Get selected variant schema with discriminator property filtered out
// and sub-fields inheriting the parent's advanced value
const selectedVariant = variants.current[selectedIndex];
const parentAdvanced = (schema as any).advanced;
function getFilteredSchema() {
if (!selectedVariant?.properties) return selectedVariant;
const filteredProperties: Record<string, any> = {};
for (const [key, value] of Object.entries(selectedVariant.properties)) {
if (key === discriminatorProp) continue;
filteredProperties[key] =
parentAdvanced !== undefined
? { ...(value as any), advanced: parentAdvanced }
: value;
}
return {
...selectedVariant,
properties: filteredProperties,
required: (selectedVariant.required || []).filter(
(r: string) => r !== discriminatorProp,
),
};
}
const filteredSchema = getFilteredSchema();
// Handle variant change
function handleVariantChange(option?: string) {
const newIndex = option !== undefined ? parseInt(option, 10) : -1;
if (newIndex === selectedIndex || newIndex < 0) return;
const newVariant = variants.current[newIndex];
const oldVariant = variants.current[selectedIndex];
const discValue = (newVariant.properties?.[discriminatorProp] as any)
?.const;
// Clean edges for this field
const handlePrefix = cleanUpHandleId(field_id);
useEdgeStore.getState().removeEdgesByHandlePrefix(nodeId, handlePrefix);
// Sanitize current data against old→new schema to preserve shared fields
let newFormData = schemaUtils.sanitizeDataForNewSchema(
newVariant,
oldVariant,
formData,
);
// Fill in defaults for the new variant
newFormData = schemaUtils.getDefaultFormState(
newVariant,
newFormData,
"excludeObjectChildren",
) as any;
newFormData = { ...newFormData, [discriminatorProp]: discValue };
setSelectedIndex(newIndex);
onChange(newFormData, props.fieldPathId.path, undefined, field_id);
}
// Sync selectedIndex when formData discriminator changes externally
// (e.g. undo/redo, loading saved state)
const currentDiscValue = formData?.[discriminatorProp];
useEffect(() => {
const idx = currentDiscValue
? enumOptions.findIndex((o) => o.discriminatorValue === currentDiscValue)
: -1;
if (idx >= 0) {
if (idx !== selectedIndex) setSelectedIndex(idx);
} else if (enumOptions.length > 0 && selectedIndex !== 0) {
// Unknown or cleared discriminator — full reset via same cleanup path
handleVariantChange("0");
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentDiscValue]);
// Auto-set discriminator on initial render if missing
useEffect(() => {
const discValue = enumOptions[selectedIndex]?.discriminatorValue;
if (discValue && formData?.[discriminatorProp] !== discValue) {
onChange(
{ ...formData, [discriminatorProp]: discValue },
props.fieldPathId.path,
undefined,
field_id,
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
const selector = (
<Widget
id={field_id}
name={`${name}__oneof_select`}
schema={{ type: "number", default: 0 }}
onChange={handleVariantChange}
onBlur={props.onBlur}
onFocus={props.onFocus}
disabled={props.disabled || enumOptions.length === 0}
multiple={false}
value={selectedIndex}
options={{ enumOptions }}
registry={registry}
placeholder={props.placeholder}
autocomplete={props.autocomplete}
className={cn("-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium")}
autofocus={props.autofocus}
label=""
hideLabel={true}
readonly={props.readonly}
/>
);
const DescriptionFieldTemplate = getTemplate(
"DescriptionFieldTemplate",
registry,
uiOptions,
);
const description_id = descriptionId(props.fieldPathId ?? "");
return (
<div>
<div className="flex items-center gap-2">
<Text variant="body" className="line-clamp-1">
{schema.title || name}
</Text>
<Text variant="small" className="mr-1 text-red-500">
{props.required ? "*" : null}
</Text>
{selector}
<DescriptionFieldTemplate
id={description_id}
description={schema.description || ""}
schema={schema}
registry={registry}
/>
</div>
{filteredSchema && filteredSchema.type !== "null" && (
<SchemaField
{...props}
schema={filteredSchema}
uiSchema={childUiSchema}
/>
)}
</div>
);
}

View File

@@ -6,11 +6,7 @@ import {
titleId,
} from "@rjsf/utils";
import {
isAnyOfChild,
isAnyOfSchema,
isOneOfSchema,
} from "../../utils/schema-utils";
import { isAnyOfChild, isAnyOfSchema } from "../../utils/schema-utils";
import {
cleanUpHandleId,
getHandleId,
@@ -86,13 +82,12 @@ export default function FieldTemplate(props: FieldTemplateProps) {
const shouldDisplayLabel =
displayLabel ||
(schema.type === "boolean" && !isAnyOfChild(uiSchema as any));
const isUnionSchema = isAnyOfSchema(schema) || isOneOfSchema(schema);
const shouldShowTitleSection = !isUnionSchema && !additional;
const shouldShowTitleSection = !isAnyOfSchema(schema) && !additional;
const shouldShowChildren =
schema.type === "object" ||
schema.type === "array" ||
isUnionSchema ||
isAnyOfSchema(schema) ||
!isHandleConnected;
const isAdvancedField = (schema as any).advanced === true;
@@ -100,7 +95,8 @@ export default function FieldTemplate(props: FieldTemplateProps) {
return null;
}
const marginBottom = isPartOfAnyOf({ uiOptions }) || isUnionSchema ? 0 : 16;
const marginBottom =
isPartOfAnyOf({ uiOptions }) || isAnyOfSchema(schema) ? 0 : 16;
return (
<WrapIfAdditionalTemplate

View File

@@ -7,7 +7,7 @@ import {
import { Text } from "@/components/atoms/Text/Text";
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
import { isAnyOfSchema, isOneOfSchema } from "../../utils/schema-utils";
import { isAnyOfSchema } from "../../utils/schema-utils";
import { cn } from "@/lib/utils";
import { cleanUpHandleId, isArrayItem } from "../../helpers";
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
@@ -18,7 +18,7 @@ export default function TitleField(props: TitleFieldProps) {
const { nodeId, showHandles } = registry.formContext;
const uiOptions = getUiOptions(uiSchema);
const isAnyOf = isAnyOfSchema(schema) || isOneOfSchema(schema);
const isAnyOf = isAnyOfSchema(schema);
const { displayType, colorClass } = getTypeDisplayInfo(schema);
const description_id = descriptionId(id);

View File

@@ -8,14 +8,6 @@ export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean {
);
}
export function isOneOfSchema(schema: RJSFSchema | undefined): boolean {
return (
Array.isArray(schema?.oneOf) &&
schema!.oneOf.length > 0 &&
schema?.enum === undefined
);
}
export const isAnyOfChild = (
uiSchema: UiSchema<any, RJSFSchema, any> | undefined,
): boolean => {