mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
27 Commits
dependabot
...
feat/llm-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
957ec038b8 | ||
|
|
9b39a662ee | ||
|
|
9b93a956b4 | ||
|
|
b236719bbf | ||
|
|
4f286f510f | ||
|
|
b1595d871d | ||
|
|
29ab7f2d9c | ||
|
|
784936b323 | ||
|
|
f2ae38a1a7 | ||
|
|
2ccfb4e4c1 | ||
|
|
5641cdd3ca | ||
|
|
c65e5c957a | ||
|
|
54355a691b | ||
|
|
3cafa49c4c | ||
|
|
ded002a406 | ||
|
|
4fdf89c3be | ||
|
|
d816bd739f | ||
|
|
bfb843a56e | ||
|
|
6a16376323 | ||
|
|
ed7b02ffb1 | ||
|
|
d064198dd1 | ||
|
|
01ad033b2b | ||
|
|
56bcbda054 | ||
|
|
684845d946 | ||
|
|
d40efc6056 | ||
|
|
6a6b23c2e1 | ||
|
|
d0a1d72e8a |
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
await validate_url_host(auth_server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
await validate_url_host(registration_endpoint)
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ async def client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
@@ -521,12 +521,12 @@ class TestStoreToken:
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -541,7 +541,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
@@ -556,7 +556,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
|
||||
@@ -37,8 +37,10 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
@@ -117,11 +119,30 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
|
||||
# When block integration lands, this should fail hard or skip block initialization
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry refreshed successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to refresh LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to migrate LLM models at startup: {e}. "
|
||||
"This is expected in test environments without AgentNode table."
|
||||
)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -348,6 +369,11 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
|
||||
@@ -34,8 +34,11 @@ from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
@@ -805,6 +808,11 @@ async def llm_call(
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||
await validate_url_host(
|
||||
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||
)
|
||||
|
||||
client = ollama.AsyncClient(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
|
||||
@@ -33,7 +33,7 @@ import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -235,7 +235,7 @@ async def _restore_browser_state(
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
|
||||
@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection via shared validate_url (backend.util.request)
|
||||
# SSRF protection via shared validate_url_host (backend.util.request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Patch target: validate_url is imported directly into agent_browser's module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
|
||||
# Patch target: validate_url_host is imported directly into agent_browser's
|
||||
# module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
|
||||
|
||||
|
||||
class TestSsrfViaValidateUrl:
|
||||
"""Verify that browser_navigate uses validate_url for SSRF protection.
|
||||
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
|
||||
|
||||
We mock validate_url itself (not the low-level socket) so these tests
|
||||
We mock validate_url_host itself (not the low-level socket) so these tests
|
||||
exercise the integration point, not the internals of request.py
|
||||
(which has its own thorough test suite in request_test.py).
|
||||
"""
|
||||
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_ip_returns_blocked_url_error(self):
|
||||
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.side_effect = ValueError(
|
||||
"Access to blocked IP 10.0.0.1 is not allowed."
|
||||
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
|
||||
assert result.error == "blocked_url"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_url_called_with_empty_trusted_origins(self):
|
||||
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
|
||||
async def test_validate_url_host_called_without_trusted_hostnames(self):
|
||||
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.return_value = (object(), False, ["1.2.3.4"])
|
||||
with patch(
|
||||
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
|
||||
session=self.session,
|
||||
url="https://example.com",
|
||||
)
|
||||
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
|
||||
mock_validate.assert_called_once_with("https://example.com")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
from backend.util.request import HTTPClientError, validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool):
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
await validate_url_host(server_url)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch(
|
||||
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
|
||||
mock_tools = _make_tool_list("fetch", "search")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
|
||||
mock_tools = _make_tool_list("push_notification")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
|
||||
text_result = "# Example Domain\nThis domain is for examples."
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
|
||||
mock_creds.access_token = SecretStr("stale-token")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
|
||||
url_with_slash = "https://mcp.example.com/mcp/"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from .model import ModelMetadata
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Functions
|
||||
"refresh_llm_registry",
|
||||
"get_model",
|
||||
"get_all_models",
|
||||
"get_enabled_models",
|
||||
"get_schema_options",
|
||||
"get_default_model_slug",
|
||||
"get_all_model_slugs_for_validation",
|
||||
]
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Type definitions for LLM model metadata.
|
||||
|
||||
Re-exports ModelMetadata from blocks.llm to avoid type collision.
|
||||
In PR #5 (block integration), this will become the canonical location.
|
||||
"""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
__all__ = ["ModelMetadata"]
|
||||
240
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
240
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
unit: str # "RUN" or "TOKENS"
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCreator:
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
website_url: str | None
|
||||
logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModel:
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any]
|
||||
extra_metadata: dict[str, Any]
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
# In-memory cache (will be replaced with Redis in PR #6)
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""
|
||||
Refresh the LLM registry from the database.
|
||||
|
||||
Fetches all models with their costs, providers, and creators,
|
||||
then updates the in-memory cache.
|
||||
"""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.info(f"Fetched {len(records)} LLM models from database")
|
||||
|
||||
# Build model instances
|
||||
new_models: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
# Parse costs
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit), # Convert enum to string
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Parse creator
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display = (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
)
|
||||
|
||||
# Extract creator name (fallback to "Unknown" if no creator)
|
||||
creator_name = (
|
||||
record.Creator.displayName if record.Creator else "Unknown"
|
||||
)
|
||||
|
||||
# Price tier defaults to 1 if not set
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier,
|
||||
)
|
||||
|
||||
# Create model instance
|
||||
model = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
new_models[record.slug] = model
|
||||
|
||||
# Atomic swap
|
||||
global _dynamic_models, _schema_options
|
||||
_dynamic_models = new_models
|
||||
_schema_options = _build_schema_options()
|
||||
|
||||
logger.info(
|
||||
f"LLM registry refreshed: {len(_dynamic_models)} models, "
|
||||
f"{len(_schema_options)} schema options"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
return [
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
for model in sorted(
|
||||
_dynamic_models.values(), key=lambda m: m.display_name.lower()
|
||||
)
|
||||
if model.is_enabled
|
||||
]
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
|
||||
"""Get a model by slug from the registry."""
|
||||
return _dynamic_models.get(slug)
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
|
||||
"""Get all models from the registry (including disabled)."""
|
||||
return list(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
|
||||
"""Get only enabled models from the registry."""
|
||||
return [model for model in _dynamic_models.values() if model.is_enabled]
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
|
||||
"""Get schema options for model selection dropdown (enabled models only)."""
|
||||
return _schema_options
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""Get the default model slug (first recommended, or first enabled)."""
|
||||
# Sort once and use next() to short-circuit on first match
|
||||
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
|
||||
|
||||
# Prefer recommended models
|
||||
recommended = next(
|
||||
(m.slug for m in models if m.is_recommended and m.is_enabled), None
|
||||
)
|
||||
if recommended:
|
||||
return recommended
|
||||
|
||||
# Fallback to first enabled model
|
||||
return next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
|
||||
"""
|
||||
Get all model slugs for validation (enables migrate_llm_models to work).
|
||||
Returns slugs for enabled models only.
|
||||
"""
|
||||
return [model.slug for model in _dynamic_models.values() if model.is_enabled]
|
||||
@@ -0,0 +1,5 @@
|
||||
"""LLM registry public API."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
67
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
67
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Pydantic models for LLM registry public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
unit: str # "RUN" or "TOKENS"
|
||||
credit_cost: int = pydantic.Field(ge=0)
|
||||
credential_provider: str
|
||||
credential_id: str | None = None
|
||||
credential_type: str | None = None
|
||||
currency: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
|
||||
"""Represents the organization that created/trained the model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
|
||||
"""Public-facing LLM model information."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
provider_name: str
|
||||
creator: LlmModelCreator | None = None
|
||||
context_window: int
|
||||
max_output_tokens: int | None = None
|
||||
price_tier: int # 1=cheapest, 2=medium, 3=expensive
|
||||
is_recommended: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
|
||||
"""Provider with its enabled models."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
|
||||
"""Response for GET /llm/models."""
|
||||
|
||||
models: list[LlmModel]
|
||||
total: int
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
|
||||
"""Response for GET /llm/providers."""
|
||||
|
||||
providers: list[LlmProvider]
|
||||
141
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
141
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Public read-only API for LLM registry."""
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data.llm_registry import (
|
||||
RegistryModelCreator,
|
||||
get_all_models,
|
||||
get_enabled_models,
|
||||
)
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(
    creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
    """Translate a registry-level creator record into its public API shape.

    Returns None when the model has no creator attached.
    """
    if not creator:
        return None
    fields = {
        "id": creator.id,
        "name": creator.name,
        "display_name": creator.display_name,
        "description": creator.description,
        "website_url": creator.website_url,
        "logo_url": creator.logo_url,
    }
    return llm_model.LlmModelCreator(**fields)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
    enabled_only: bool = fastapi.Query(
        default=True, description="Only return enabled models"
    ),
):
    """
    List all LLM models available to users.

    Returns models from the in-memory registry cache (no DB round-trip).
    Use enabled_only=true to filter to only enabled models (default).
    """

    def _to_api_model(model) -> llm_model.LlmModel:
        """Map one registry model (including its costs) to the public API shape."""
        return llm_model.LlmModel(
            slug=model.slug,
            display_name=model.display_name,
            description=model.description,
            provider_name=model.provider_display_name,
            creator=_map_creator(model.creator),
            context_window=model.metadata.context_window,
            max_output_tokens=model.metadata.max_output_tokens,
            price_tier=model.metadata.price_tier,
            is_recommended=model.is_recommended,
            capabilities=model.capabilities,
            costs=[
                llm_model.LlmModelCost(
                    unit=cost.unit,
                    credit_cost=cost.credit_cost,
                    credential_provider=cost.credential_provider,
                    credential_id=cost.credential_id,
                    credential_type=cost.credential_type,
                    currency=cost.currency,
                    metadata=cost.metadata,
                )
                for cost in model.costs
            ],
        )

    # Get models from the in-memory registry.
    registry_models = get_enabled_models() if enabled_only else get_all_models()

    models = [_to_api_model(model) for model in registry_models]
    return llm_model.LlmModelsResponse(models=models, total=len(models))
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """
    List all LLM providers with their enabled models.

    Groups enabled models by provider from the in-memory registry.
    Providers are sorted by key; each provider's models are sorted
    by display name.
    """

    def _to_api_model(model) -> llm_model.LlmModel:
        """Map one registry model (including its costs) to the public API shape."""
        return llm_model.LlmModel(
            slug=model.slug,
            display_name=model.display_name,
            description=model.description,
            provider_name=model.provider_display_name,
            creator=_map_creator(model.creator),
            context_window=model.metadata.context_window,
            max_output_tokens=model.metadata.max_output_tokens,
            price_tier=model.metadata.price_tier,
            is_recommended=model.is_recommended,
            capabilities=model.capabilities,
            costs=[
                llm_model.LlmModelCost(
                    unit=cost.unit,
                    credit_cost=cost.credit_cost,
                    credential_provider=cost.credential_provider,
                    credential_id=cost.credential_id,
                    credential_type=cost.credential_type,
                    currency=cost.currency,
                    metadata=cost.metadata,
                )
                for cost in model.costs
            ],
        )

    # Get all enabled models and group them by provider key.
    registry_models = get_enabled_models()
    provider_map: dict[str, list] = {}
    for model in registry_models:
        provider_map.setdefault(model.metadata.provider, []).append(model)

    # Build provider responses in deterministic (sorted-by-key) order.
    providers = []
    for provider_key, models in sorted(provider_map.items()):
        # Use the first model's provider display name; fall back to the key.
        display_name = models[0].provider_display_name if models else provider_key

        providers.append(
            llm_model.LlmProvider(
                name=provider_key,
                display_name=display_name,
                models=[
                    _to_api_model(model)
                    for model in sorted(models, key=lambda m: m.display_name)
                ],
            )
        )

    return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]:
|
||||
return ip_addresses
|
||||
|
||||
|
||||
async def validate_url(
|
||||
url: str, trusted_origins: list[str]
|
||||
async def validate_url_host(
|
||||
url: str, trusted_hostnames: Optional[list[str]] = None
|
||||
) -> tuple[URL, bool, list[str]]:
|
||||
"""
|
||||
Validates the URL to prevent SSRF attacks by ensuring it does not point
|
||||
to a private, link-local, or otherwise blocked IP address — unless
|
||||
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
|
||||
point to a private, link-local, or otherwise blocked IP address — unless
|
||||
the hostname is explicitly trusted.
|
||||
|
||||
Hosts in `trusted_hostnames` are permitted without checks.
|
||||
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
|
||||
|
||||
Params:
|
||||
url: A hostname, netloc, or URL to validate.
|
||||
If no scheme is included, `http://` is assumed.
|
||||
trusted_hostnames: A list of hostnames that don't require validation.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- if the URL has a disallowed URL scheme
|
||||
- if the URL/host string can't be parsed
|
||||
- if the hostname contains invalid or unsupported (non-ASCII) characters
|
||||
- if the host resolves to a blocked IP
|
||||
|
||||
Returns:
|
||||
str: The validated, canonicalized, parsed URL
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
1. The validated, canonicalized, parsed host/URL,
|
||||
with hostname ASCII-safe encoding
|
||||
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
|
||||
3. List of resolved IP addresses for the host; empty if the host is trusted.
|
||||
"""
|
||||
parsed = parse_url(url)
|
||||
|
||||
# Check scheme
|
||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
|
||||
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
|
||||
f"{', '.join(ALLOWED_SCHEMES)}"
|
||||
)
|
||||
|
||||
# Validate and IDNA encode hostname
|
||||
if not parsed.hostname:
|
||||
raise ValueError("Invalid URL: No hostname found.")
|
||||
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
|
||||
|
||||
# IDNA encode to prevent Unicode domain attacks
|
||||
try:
|
||||
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
|
||||
except idna.IDNAError:
|
||||
raise ValueError("Invalid hostname with unsupported characters.")
|
||||
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
|
||||
|
||||
# Check hostname characters
|
||||
if not HOSTNAME_REGEX.match(ascii_hostname):
|
||||
raise ValueError("Hostname contains invalid characters.")
|
||||
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
|
||||
|
||||
# Check if hostname is trusted
|
||||
is_trusted = ascii_hostname in trusted_origins
|
||||
|
||||
# If not trusted, validate IP addresses
|
||||
ip_addresses: list[str] = []
|
||||
if not is_trusted:
|
||||
# Resolve all IP addresses for the hostname
|
||||
ip_addresses = await _resolve_host(ascii_hostname)
|
||||
|
||||
# Block any IP address that belongs to a blocked range
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
|
||||
netloc = ascii_hostname
|
||||
if parsed.port:
|
||||
netloc = f"{ascii_hostname}:{parsed.port}"
|
||||
|
||||
return (
|
||||
URL(
|
||||
parsed.scheme,
|
||||
netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
),
|
||||
is_trusted,
|
||||
ip_addresses,
|
||||
# Re-create parsed URL object with IDNA-encoded hostname
|
||||
parsed = URL(
|
||||
parsed.scheme,
|
||||
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
|
||||
is_trusted = trusted_hostnames and any(
|
||||
matches_allowed_host(parsed, allowed)
|
||||
for allowed in (
|
||||
# Normalize + parse allowlist entries the same way for consistent matching
|
||||
parse_url(w)
|
||||
for w in trusted_hostnames
|
||||
)
|
||||
)
|
||||
|
||||
if is_trusted:
|
||||
return parsed, True, []
|
||||
|
||||
# If not allowlisted, go ahead with host resolution and IP target check
|
||||
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
|
||||
|
||||
|
||||
def matches_allowed_host(url: URL, allowed: URL) -> bool:
|
||||
if url.hostname != allowed.hostname:
|
||||
return False
|
||||
|
||||
# Allow any port if not explicitly specified in the allowlist
|
||||
if allowed.port is None:
|
||||
return True
|
||||
|
||||
return url.port == allowed.port
|
||||
|
||||
|
||||
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
|
||||
"""
|
||||
Resolves hostname to IPs and raises ValueError if any resolve to
|
||||
a blocked network. Returns the list of resolved IP addresses.
|
||||
"""
|
||||
ip_addresses = await _resolve_host(hostname)
|
||||
for ip_str in ip_addresses:
|
||||
if _is_ip_blocked(ip_str):
|
||||
raise ValueError(
|
||||
f"Access to blocked or private IP address {ip_str} "
|
||||
f"for hostname {hostname} is not allowed."
|
||||
)
|
||||
return ip_addresses
|
||||
|
||||
|
||||
def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
@@ -352,7 +382,7 @@ class Requests:
|
||||
):
|
||||
self.trusted_origins = []
|
||||
for url in trusted_origins or []:
|
||||
hostname = urlparse(url).hostname
|
||||
hostname = parse_url(url).netloc # {host}[:{port}]
|
||||
if not hostname:
|
||||
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
|
||||
self.trusted_origins.append(hostname)
|
||||
@@ -450,7 +480,7 @@ class Requests:
|
||||
data = form
|
||||
|
||||
# Validate URL and get trust status
|
||||
parsed_url, is_trusted, ip_addresses = await validate_url(
|
||||
parsed_url, is_trusted, ip_addresses = await validate_url_host(
|
||||
url, self.trusted_origins
|
||||
)
|
||||
|
||||
@@ -503,7 +533,6 @@ class Requests:
|
||||
json=json,
|
||||
**kwargs,
|
||||
) as response:
|
||||
|
||||
if self.raise_for_status:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from backend.util.request import pin_url, validate_url
|
||||
from backend.util.request import pin_url, validate_url_host
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
|
||||
):
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
await validate_url(raw_url, trusted_origins)
|
||||
await validate_url_host(raw_url, trusted_origins)
|
||||
else:
|
||||
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
|
||||
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
|
||||
assert validated_url.geturl() == expected_value
|
||||
|
||||
|
||||
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
|
||||
if expect_error:
|
||||
# If any IP is blocked, we expect a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
url, _, ip_addresses = await validate_url(hostname, [])
|
||||
url, _, ip_addresses = await validate_url_host(hostname)
|
||||
pin_url(url, ip_addresses)
|
||||
else:
|
||||
url, _, ip_addresses = await validate_url(hostname, [])
|
||||
url, _, ip_addresses = await validate_url_host(hostname)
|
||||
pinned_url = pin_url(url, ip_addresses).geturl()
|
||||
# The pinned_url should contain the first valid IP
|
||||
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")
|
||||
|
||||
@@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCreator" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"websiteUrl" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"creatorId" TEXT,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"priceTier" INTEGER NOT NULL DEFAULT 1,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"isRecommended" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}',
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelMigration" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"sourceModelSlug" TEXT NOT NULL,
|
||||
"targetModelSlug" TEXT NOT NULL,
|
||||
"reason" TEXT,
|
||||
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
|
||||
"nodeCount" INTEGER NOT NULL,
|
||||
"customCreditCost" INTEGER,
|
||||
"isReverted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"revertedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- CreateIndex (partial unique for default costs - no specific credential)
|
||||
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;
|
||||
|
||||
-- CreateIndex (partial unique for credential-specific costs)
|
||||
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");
|
||||
|
||||
-- CreateIndex (partial unique to prevent multiple active migrations per source)
|
||||
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddCheckConstraints (enforce data integrity)
|
||||
ALTER TABLE "LlmModel"
|
||||
ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);
|
||||
|
||||
ALTER TABLE "LlmModelCost"
|
||||
ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);
|
||||
|
||||
ALTER TABLE "LlmModelMigration"
|
||||
ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
|
||||
ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
|
||||
@@ -1304,3 +1304,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLM Registry Models
|
||||
// ============================================================================
|
||||
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
// Model-specific capabilities
|
||||
// These vary per model even within the same provider (e.g., Hugging Face)
|
||||
// Default to false for safety - partially-seeded rows should not be assumed capable
|
||||
supportsTools Boolean @default(false)
|
||||
supportsJsonOutput Boolean @default(false)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelToolCalls Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
|
||||
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
// Note: slug already has @unique which creates an implicit index
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int // DB constraint: >= 0
|
||||
|
||||
// Provider identifier (e.g., "openai", "anthropic", "openrouter")
|
||||
// Used to determine which credential system provides the API key.
|
||||
// Allows different pricing for:
|
||||
// - Default provider costs (WHERE credentialId IS NULL)
|
||||
// - User's own API key costs (WHERE credentialId IS NOT NULL)
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Note: Unique constraints are implemented as partial indexes in migration SQL:
|
||||
// - One for default costs (WHERE credentialId IS NULL)
|
||||
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
|
||||
// This allows both provider-level defaults and credential-specific overrides
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModelMigration {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
sourceModelSlug String // The original model that was disabled
|
||||
targetModelSlug String // The model workflows were migrated to
|
||||
reason String? // Why the migration happened (e.g., "Provider outage")
|
||||
|
||||
// FK constraints ensure slugs reference valid models
|
||||
SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
|
||||
TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)
|
||||
|
||||
// Track affected nodes as JSON array of node IDs
|
||||
// Format: ["node-uuid-1", "node-uuid-2", ...]
|
||||
migratedNodeIds Json @default("[]")
|
||||
nodeCount Int // Number of nodes migrated (DB constraint: >= 0)
|
||||
|
||||
// Custom pricing override for migrated workflows during the migration period.
|
||||
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
|
||||
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
|
||||
// to avoid billing surprises, or offer a discount during the transition.
|
||||
//
|
||||
// IMPORTANT: This field is intended for integration with the billing system.
|
||||
// When billing calculates costs for nodes affected by this migration, it should
|
||||
// check if customCreditCost is set and use it instead of the target model's cost.
|
||||
// If null, the target model's normal cost applies.
|
||||
//
|
||||
// TODO: Integrate with billing system to apply this override during cost calculation.
|
||||
// LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
|
||||
// For token-priced models, this may be ambiguous. Consider migrating to a relation
|
||||
// with LlmModelCost or a dedicated override model in a follow-up PR.
|
||||
customCreditCost Int? // DB constraint: >= 0 when not null
|
||||
|
||||
// Revert tracking
|
||||
isReverted Boolean @default(false)
|
||||
revertedAt DateTime?
|
||||
|
||||
// Note: Partial unique index in migration SQL prevents multiple active migrations per source:
|
||||
// UNIQUE (sourceModelSlug) WHERE isReverted = false
|
||||
@@index([targetModelSlug])
|
||||
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"use server";
|
||||
|
||||
import BackendAPI from "@/lib/autogpt-server-api/client";
|
||||
import { OttoQuery, OttoResponse } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
const api = new BackendAPI();
|
||||
|
||||
export async function askOtto(
|
||||
query: string,
|
||||
conversationHistory: { query: string; response: string }[],
|
||||
includeGraphData: boolean,
|
||||
graphId?: string,
|
||||
): Promise<OttoResponse> {
|
||||
const messageId = `${Date.now()}-web`;
|
||||
|
||||
const ottoQuery: OttoQuery = {
|
||||
query,
|
||||
conversation_history: conversationHistory,
|
||||
message_id: messageId,
|
||||
include_graph_data: includeGraphData,
|
||||
graph_id: graphId,
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await api.askOtto(ottoQuery);
|
||||
return response;
|
||||
} catch (error) {
|
||||
console.error("Error in askOtto server action:", error);
|
||||
return {
|
||||
answer: error instanceof Error ? error.message : "Unknown error occurred",
|
||||
documents: [],
|
||||
success: false,
|
||||
error: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,12 @@ import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
|
||||
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
|
||||
import { useCustomNode } from "./useCustomNode";
|
||||
|
||||
function hasAdvancedFields(schema: RJSFSchema): boolean {
|
||||
const properties = schema?.properties;
|
||||
if (!properties) return false;
|
||||
return Object.values(properties).some((prop: any) => prop.advanced === true);
|
||||
}
|
||||
|
||||
export type CustomNodeData = {
|
||||
hardcodedValues: {
|
||||
[key: string]: any;
|
||||
@@ -108,7 +114,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
)}
|
||||
showHandles={showHandles}
|
||||
/>
|
||||
<NodeAdvancedToggle nodeId={nodeId} />
|
||||
<NodeAdvancedToggle
|
||||
nodeId={nodeId}
|
||||
isLastSection={data.uiType === BlockUIType.OUTPUT}
|
||||
hasAdvancedFields={hasAdvancedFields(inputSchema)}
|
||||
/>
|
||||
{data.uiType != BlockUIType.OUTPUT && (
|
||||
<OutputHandler
|
||||
uiType={data.uiType}
|
||||
|
||||
@@ -2,18 +2,33 @@ import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CaretDownIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isLastSection?: boolean;
|
||||
hasAdvancedFields?: boolean;
|
||||
};
|
||||
|
||||
export function NodeAdvancedToggle({ nodeId }: Props) {
|
||||
export function NodeAdvancedToggle({
|
||||
nodeId,
|
||||
isLastSection,
|
||||
hasAdvancedFields = true,
|
||||
}: Props) {
|
||||
const showAdvanced = useNodeStore(
|
||||
(state) => state.nodeAdvancedStates[nodeId] || false,
|
||||
);
|
||||
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
|
||||
|
||||
if (!hasAdvancedFields) return null;
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-start gap-2 bg-white px-5 pb-3.5">
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-start gap-2 bg-white px-5 pb-3.5",
|
||||
isLastSection && "rounded-b-xlarge",
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CaretDownIcon, InfoIcon } from "@phosphor-icons/react";
|
||||
import { CaretDownIcon, CaretRightIcon, InfoIcon } from "@phosphor-icons/react";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -30,13 +30,41 @@ export const OutputHandler = ({
|
||||
const properties = outputSchema?.properties || {};
|
||||
const [isOutputVisible, setIsOutputVisible] = useState(true);
|
||||
const brokenOutputs = useBrokenOutputs(nodeId);
|
||||
const [expandedObjects, setExpandedObjects] = useState<
|
||||
Record<string, boolean>
|
||||
>({});
|
||||
|
||||
const showHandles = uiType !== BlockUIType.OUTPUT;
|
||||
|
||||
function toggleObjectExpanded(key: string) {
|
||||
setExpandedObjects((prev) => ({ ...prev, [key]: !prev[key] }));
|
||||
}
|
||||
|
||||
function hasConnectedOrBrokenDescendant(
|
||||
schema: RJSFSchema,
|
||||
keyPrefix: string,
|
||||
): boolean {
|
||||
if (!schema) return false;
|
||||
return Object.entries(schema).some(
|
||||
([key, fieldSchema]: [string, RJSFSchema]) => {
|
||||
const fullKey = keyPrefix ? `${keyPrefix}_#_${key}` : key;
|
||||
if (isOutputConnected(nodeId, fullKey) || brokenOutputs.has(fullKey))
|
||||
return true;
|
||||
if (fieldSchema?.properties)
|
||||
return hasConnectedOrBrokenDescendant(
|
||||
fieldSchema.properties,
|
||||
fullKey,
|
||||
);
|
||||
return false;
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const renderOutputHandles = (
|
||||
schema: RJSFSchema,
|
||||
keyPrefix: string = "",
|
||||
titlePrefix: string = "",
|
||||
connectedOnly: boolean = false,
|
||||
): React.ReactNode[] => {
|
||||
return Object.entries(schema).map(
|
||||
([key, fieldSchema]: [string, RJSFSchema]) => {
|
||||
@@ -44,10 +72,23 @@ export const OutputHandler = ({
|
||||
const fieldTitle = titlePrefix + (fieldSchema?.title || key);
|
||||
|
||||
const isConnected = isOutputConnected(nodeId, fullKey);
|
||||
const shouldShow = isConnected || isOutputVisible;
|
||||
const isBroken = brokenOutputs.has(fullKey);
|
||||
const hasNestedProperties = !!fieldSchema?.properties;
|
||||
const selfIsRelevant = isConnected || isBroken;
|
||||
const descendantIsRelevant =
|
||||
hasNestedProperties &&
|
||||
hasConnectedOrBrokenDescendant(fieldSchema.properties!, fullKey);
|
||||
|
||||
const shouldShow = connectedOnly
|
||||
? selfIsRelevant || descendantIsRelevant
|
||||
: isOutputVisible || selfIsRelevant || descendantIsRelevant;
|
||||
|
||||
const { displayType, colorClass, hexColor } =
|
||||
getTypeDisplayInfo(fieldSchema);
|
||||
const isBroken = brokenOutputs.has(fullKey);
|
||||
const isExpanded = expandedObjects[fullKey] ?? false;
|
||||
|
||||
// User expanded → show all children; auto-expanded → filter to connected only
|
||||
const shouldRenderChildren = isExpanded || descendantIsRelevant;
|
||||
|
||||
return shouldShow ? (
|
||||
<div
|
||||
@@ -56,6 +97,19 @@ export const OutputHandler = ({
|
||||
data-tutorial-id={`output-handler-${nodeId}-${fieldTitle}`}
|
||||
>
|
||||
<div className="relative flex items-center gap-2">
|
||||
{hasNestedProperties && (
|
||||
<button
|
||||
onClick={() => toggleObjectExpanded(fullKey)}
|
||||
className="flex items-center text-slate-500 hover:text-slate-700"
|
||||
aria-label={isExpanded ? "Collapse" : "Expand"}
|
||||
>
|
||||
{isExpanded ? (
|
||||
<CaretDownIcon size={12} weight="bold" />
|
||||
) : (
|
||||
<CaretRightIcon size={12} weight="bold" />
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
{fieldSchema?.description && (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
@@ -102,12 +156,14 @@ export const OutputHandler = ({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Recursively render nested properties */}
|
||||
{fieldSchema?.properties &&
|
||||
{/* Nested properties */}
|
||||
{hasNestedProperties &&
|
||||
shouldRenderChildren &&
|
||||
renderOutputHandles(
|
||||
fieldSchema.properties,
|
||||
fieldSchema.properties!,
|
||||
fullKey,
|
||||
`${fieldTitle}.`,
|
||||
"",
|
||||
!isExpanded,
|
||||
)}
|
||||
</div>
|
||||
) : null;
|
||||
@@ -136,7 +192,7 @@ export const OutputHandler = ({
|
||||
</Button>
|
||||
|
||||
<div className="flex flex-col items-end gap-2">
|
||||
{renderOutputHandles(properties)}
|
||||
{renderOutputHandles(properties, "", "", !isOutputVisible)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -25,34 +25,60 @@ type HistoryStore = {
|
||||
|
||||
const MAX_HISTORY = 50;
|
||||
|
||||
// Microtask batching state — kept outside the store to avoid triggering
|
||||
// re-renders. When multiple pushState calls happen in the same synchronous
|
||||
// execution (e.g. node deletion cascading to edge cleanup), only the first
|
||||
// (pre-change) state is kept and committed as a single history entry.
|
||||
let pendingState: HistoryState | null = null;
|
||||
let batchScheduled = false;
|
||||
|
||||
export const useHistoryStore = create<HistoryStore>((set, get) => ({
|
||||
past: [{ nodes: [], edges: [] }],
|
||||
future: [],
|
||||
|
||||
pushState: (state: HistoryState) => {
|
||||
const { past } = get();
|
||||
const lastState = past[past.length - 1];
|
||||
|
||||
if (lastState && isEqual(lastState, state)) {
|
||||
return;
|
||||
// Keep only the first state within a microtask batch — it represents
|
||||
// the true pre-change snapshot before any cascading mutations.
|
||||
if (!pendingState) {
|
||||
pendingState = state;
|
||||
}
|
||||
|
||||
const actualCurrentState = {
|
||||
nodes: useNodeStore.getState().nodes,
|
||||
edges: useEdgeStore.getState().edges,
|
||||
};
|
||||
if (!batchScheduled) {
|
||||
batchScheduled = true;
|
||||
queueMicrotask(() => {
|
||||
const stateToCommit = pendingState;
|
||||
pendingState = null;
|
||||
batchScheduled = false;
|
||||
|
||||
if (isEqual(state, actualCurrentState)) {
|
||||
return;
|
||||
if (!stateToCommit) return;
|
||||
|
||||
const { past } = get();
|
||||
const lastState = past[past.length - 1];
|
||||
|
||||
if (lastState && isEqual(lastState, stateToCommit)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const actualCurrentState = {
|
||||
nodes: useNodeStore.getState().nodes,
|
||||
edges: useEdgeStore.getState().edges,
|
||||
};
|
||||
|
||||
if (isEqual(stateToCommit, actualCurrentState)) {
|
||||
return;
|
||||
}
|
||||
|
||||
set((prev) => ({
|
||||
past: [...prev.past.slice(-MAX_HISTORY + 1), stateToCommit],
|
||||
future: [],
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
set((prev) => ({
|
||||
past: [...prev.past.slice(-MAX_HISTORY + 1), state],
|
||||
future: [],
|
||||
}));
|
||||
},
|
||||
|
||||
initializeHistory: () => {
|
||||
pendingState = null;
|
||||
|
||||
const currentNodes = useNodeStore.getState().nodes;
|
||||
const currentEdges = useEdgeStore.getState().edges;
|
||||
|
||||
@@ -122,5 +148,8 @@ export const useHistoryStore = create<HistoryStore>((set, get) => ({
|
||||
},
|
||||
canRedo: () => get().future.length > 0,
|
||||
|
||||
clear: () => set({ past: [{ nodes: [], edges: [] }], future: [] }),
|
||||
clear: () => {
|
||||
pendingState = null;
|
||||
set({ past: [{ nodes: [], edges: [] }], future: [] });
|
||||
},
|
||||
}));
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,10 +34,7 @@ export function FormRenderer({
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn("mb-6 mt-4", className)}
|
||||
data-tutorial-id="input-handles"
|
||||
>
|
||||
<div className={cn("mt-4", className)} data-tutorial-id="input-handles">
|
||||
<Form
|
||||
formContext={formContext}
|
||||
idPrefix="agpt"
|
||||
|
||||
@@ -63,7 +63,6 @@ export const useAnyOfField = (props: FieldProps) => {
|
||||
);
|
||||
|
||||
const handlePrefix = cleanUpHandleId(field_id);
|
||||
console.log("handlePrefix", handlePrefix);
|
||||
useEdgeStore
|
||||
.getState()
|
||||
.removeEdgesByHandlePrefix(registry.formContext.nodeId, handlePrefix);
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
TemplatesType,
|
||||
} from "@rjsf/utils";
|
||||
import { AnyOfField } from "./anyof/AnyOfField";
|
||||
import { OneOfField } from "./oneof/OneOfField";
|
||||
import {
|
||||
ArrayFieldItemTemplate,
|
||||
ArrayFieldTemplate,
|
||||
@@ -32,6 +33,7 @@ const NoButton = () => null;
|
||||
export function generateBaseFields(): RegistryFieldsType {
|
||||
return {
|
||||
AnyOfField,
|
||||
OneOfField,
|
||||
ArraySchemaField,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
import {
|
||||
descriptionId,
|
||||
FieldProps,
|
||||
getTemplate,
|
||||
getUiOptions,
|
||||
getWidget,
|
||||
} from "@rjsf/utils";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { AnyOfField } from "../anyof/AnyOfField";
|
||||
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { ANY_OF_FLAG } from "../../constants";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function getDiscriminatorPropName(schema: any): string | undefined {
|
||||
if (!schema?.discriminator) return undefined;
|
||||
if (typeof schema.discriminator === "string") return schema.discriminator;
|
||||
return schema.discriminator.propertyName;
|
||||
}
|
||||
|
||||
export function OneOfField(props: FieldProps) {
|
||||
const { schema } = props;
|
||||
|
||||
const discriminatorProp = getDiscriminatorPropName(schema);
|
||||
if (!discriminatorProp) {
|
||||
return <AnyOfField {...props} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<DiscriminatedUnionField {...props} discriminatorProp={discriminatorProp} />
|
||||
);
|
||||
}
|
||||
|
||||
interface DiscriminatedUnionFieldProps extends FieldProps {
|
||||
discriminatorProp: string;
|
||||
}
|
||||
|
||||
function DiscriminatedUnionField({
|
||||
discriminatorProp,
|
||||
...props
|
||||
}: DiscriminatedUnionFieldProps) {
|
||||
const { schema, registry, formData, onChange, name } = props;
|
||||
const { fields, schemaUtils, formContext } = registry;
|
||||
const { SchemaField } = fields;
|
||||
const { nodeId } = formContext;
|
||||
|
||||
const field_id = props.fieldPathId.$id;
|
||||
|
||||
// Resolve variant schemas from $refs
|
||||
const variants = useRef(
|
||||
(schema.oneOf || []).map((opt: any) =>
|
||||
schemaUtils.retrieveSchema(opt, formData),
|
||||
),
|
||||
);
|
||||
|
||||
// Build dropdown options from variant titles and discriminator const values
|
||||
const enumOptions = variants.current.map((variant: any, index: number) => {
|
||||
const discValue = (variant.properties?.[discriminatorProp] as any)?.const;
|
||||
return {
|
||||
value: index,
|
||||
label: variant.title || discValue || `Option ${index + 1}`,
|
||||
discriminatorValue: discValue,
|
||||
};
|
||||
});
|
||||
|
||||
// Determine initial selected index from formData
|
||||
function getInitialIndex() {
|
||||
const currentDisc = formData?.[discriminatorProp];
|
||||
if (currentDisc) {
|
||||
const idx = enumOptions.findIndex(
|
||||
(o) => o.discriminatorValue === currentDisc,
|
||||
);
|
||||
if (idx >= 0) return idx;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
const [selectedIndex, setSelectedIndex] = useState(getInitialIndex);
|
||||
|
||||
// Generate handleId for sub-fields (same convention as AnyOfField)
|
||||
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
|
||||
const handleId = getHandleId({
|
||||
uiOptions,
|
||||
id: field_id + ANY_OF_FLAG,
|
||||
schema,
|
||||
});
|
||||
|
||||
const childUiSchema = updateUiOption(props.uiSchema, {
|
||||
handleId,
|
||||
label: false,
|
||||
fromAnyOf: true,
|
||||
});
|
||||
|
||||
// Get selected variant schema with discriminator property filtered out
|
||||
// and sub-fields inheriting the parent's advanced value
|
||||
const selectedVariant = variants.current[selectedIndex];
|
||||
const parentAdvanced = (schema as any).advanced;
|
||||
|
||||
function getFilteredSchema() {
|
||||
if (!selectedVariant?.properties) return selectedVariant;
|
||||
const filteredProperties: Record<string, any> = {};
|
||||
for (const [key, value] of Object.entries(selectedVariant.properties)) {
|
||||
if (key === discriminatorProp) continue;
|
||||
filteredProperties[key] =
|
||||
parentAdvanced !== undefined
|
||||
? { ...(value as any), advanced: parentAdvanced }
|
||||
: value;
|
||||
}
|
||||
return {
|
||||
...selectedVariant,
|
||||
properties: filteredProperties,
|
||||
required: (selectedVariant.required || []).filter(
|
||||
(r: string) => r !== discriminatorProp,
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
const filteredSchema = getFilteredSchema();
|
||||
|
||||
// Handle variant change
|
||||
function handleVariantChange(option?: string) {
|
||||
const newIndex = option !== undefined ? parseInt(option, 10) : -1;
|
||||
if (newIndex === selectedIndex || newIndex < 0) return;
|
||||
|
||||
const newVariant = variants.current[newIndex];
|
||||
const oldVariant = variants.current[selectedIndex];
|
||||
const discValue = (newVariant.properties?.[discriminatorProp] as any)
|
||||
?.const;
|
||||
|
||||
// Clean edges for this field
|
||||
const handlePrefix = cleanUpHandleId(field_id);
|
||||
useEdgeStore.getState().removeEdgesByHandlePrefix(nodeId, handlePrefix);
|
||||
|
||||
// Sanitize current data against old→new schema to preserve shared fields
|
||||
let newFormData = schemaUtils.sanitizeDataForNewSchema(
|
||||
newVariant,
|
||||
oldVariant,
|
||||
formData,
|
||||
);
|
||||
|
||||
// Fill in defaults for the new variant
|
||||
newFormData = schemaUtils.getDefaultFormState(
|
||||
newVariant,
|
||||
newFormData,
|
||||
"excludeObjectChildren",
|
||||
) as any;
|
||||
newFormData = { ...newFormData, [discriminatorProp]: discValue };
|
||||
|
||||
setSelectedIndex(newIndex);
|
||||
onChange(newFormData, props.fieldPathId.path, undefined, field_id);
|
||||
}
|
||||
|
||||
// Sync selectedIndex when formData discriminator changes externally
|
||||
// (e.g. undo/redo, loading saved state)
|
||||
const currentDiscValue = formData?.[discriminatorProp];
|
||||
useEffect(() => {
|
||||
const idx = currentDiscValue
|
||||
? enumOptions.findIndex((o) => o.discriminatorValue === currentDiscValue)
|
||||
: -1;
|
||||
|
||||
if (idx >= 0) {
|
||||
if (idx !== selectedIndex) setSelectedIndex(idx);
|
||||
} else if (enumOptions.length > 0 && selectedIndex !== 0) {
|
||||
// Unknown or cleared discriminator — full reset via same cleanup path
|
||||
handleVariantChange("0");
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentDiscValue]);
|
||||
|
||||
// Auto-set discriminator on initial render if missing
|
||||
useEffect(() => {
|
||||
const discValue = enumOptions[selectedIndex]?.discriminatorValue;
|
||||
if (discValue && formData?.[discriminatorProp] !== discValue) {
|
||||
onChange(
|
||||
{ ...formData, [discriminatorProp]: discValue },
|
||||
props.fieldPathId.path,
|
||||
undefined,
|
||||
field_id,
|
||||
);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
|
||||
|
||||
const selector = (
|
||||
<Widget
|
||||
id={field_id}
|
||||
name={`${name}__oneof_select`}
|
||||
schema={{ type: "number", default: 0 }}
|
||||
onChange={handleVariantChange}
|
||||
onBlur={props.onBlur}
|
||||
onFocus={props.onFocus}
|
||||
disabled={props.disabled || enumOptions.length === 0}
|
||||
multiple={false}
|
||||
value={selectedIndex}
|
||||
options={{ enumOptions }}
|
||||
registry={registry}
|
||||
placeholder={props.placeholder}
|
||||
autocomplete={props.autocomplete}
|
||||
className={cn("-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium")}
|
||||
autofocus={props.autofocus}
|
||||
label=""
|
||||
hideLabel={true}
|
||||
readonly={props.readonly}
|
||||
/>
|
||||
);
|
||||
|
||||
const DescriptionFieldTemplate = getTemplate(
|
||||
"DescriptionFieldTemplate",
|
||||
registry,
|
||||
uiOptions,
|
||||
);
|
||||
const description_id = descriptionId(props.fieldPathId ?? "");
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="body" className="line-clamp-1">
|
||||
{schema.title || name}
|
||||
</Text>
|
||||
<Text variant="small" className="mr-1 text-red-500">
|
||||
{props.required ? "*" : null}
|
||||
</Text>
|
||||
{selector}
|
||||
<DescriptionFieldTemplate
|
||||
id={description_id}
|
||||
description={schema.description || ""}
|
||||
schema={schema}
|
||||
registry={registry}
|
||||
/>
|
||||
</div>
|
||||
{filteredSchema && filteredSchema.type !== "null" && (
|
||||
<SchemaField
|
||||
{...props}
|
||||
schema={filteredSchema}
|
||||
uiSchema={childUiSchema}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -6,7 +6,11 @@ import {
|
||||
titleId,
|
||||
} from "@rjsf/utils";
|
||||
|
||||
import { isAnyOfChild, isAnyOfSchema } from "../../utils/schema-utils";
|
||||
import {
|
||||
isAnyOfChild,
|
||||
isAnyOfSchema,
|
||||
isOneOfSchema,
|
||||
} from "../../utils/schema-utils";
|
||||
import {
|
||||
cleanUpHandleId,
|
||||
getHandleId,
|
||||
@@ -82,12 +86,13 @@ export default function FieldTemplate(props: FieldTemplateProps) {
|
||||
const shouldDisplayLabel =
|
||||
displayLabel ||
|
||||
(schema.type === "boolean" && !isAnyOfChild(uiSchema as any));
|
||||
const shouldShowTitleSection = !isAnyOfSchema(schema) && !additional;
|
||||
const isUnionSchema = isAnyOfSchema(schema) || isOneOfSchema(schema);
|
||||
const shouldShowTitleSection = !isUnionSchema && !additional;
|
||||
|
||||
const shouldShowChildren =
|
||||
schema.type === "object" ||
|
||||
schema.type === "array" ||
|
||||
isAnyOfSchema(schema) ||
|
||||
isUnionSchema ||
|
||||
!isHandleConnected;
|
||||
|
||||
const isAdvancedField = (schema as any).advanced === true;
|
||||
@@ -95,8 +100,7 @@ export default function FieldTemplate(props: FieldTemplateProps) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const marginBottom =
|
||||
isPartOfAnyOf({ uiOptions }) || isAnyOfSchema(schema) ? 0 : 16;
|
||||
const marginBottom = isPartOfAnyOf({ uiOptions }) || isUnionSchema ? 0 : 16;
|
||||
|
||||
return (
|
||||
<WrapIfAdditionalTemplate
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
|
||||
import { isAnyOfSchema } from "../../utils/schema-utils";
|
||||
import { isAnyOfSchema, isOneOfSchema } from "../../utils/schema-utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cleanUpHandleId, isArrayItem } from "../../helpers";
|
||||
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
|
||||
@@ -18,7 +18,7 @@ export default function TitleField(props: TitleFieldProps) {
|
||||
const { nodeId, showHandles } = registry.formContext;
|
||||
const uiOptions = getUiOptions(uiSchema);
|
||||
|
||||
const isAnyOf = isAnyOfSchema(schema);
|
||||
const isAnyOf = isAnyOfSchema(schema) || isOneOfSchema(schema);
|
||||
const { displayType, colorClass } = getTypeDisplayInfo(schema);
|
||||
const description_id = descriptionId(id);
|
||||
|
||||
|
||||
@@ -8,6 +8,14 @@ export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean {
|
||||
);
|
||||
}
|
||||
|
||||
export function isOneOfSchema(schema: RJSFSchema | undefined): boolean {
|
||||
return (
|
||||
Array.isArray(schema?.oneOf) &&
|
||||
schema!.oneOf.length > 0 &&
|
||||
schema?.enum === undefined
|
||||
);
|
||||
}
|
||||
|
||||
export const isAnyOfChild = (
|
||||
uiSchema: UiSchema<any, RJSFSchema, any> | undefined,
|
||||
): boolean => {
|
||||
|
||||
Reference in New Issue
Block a user