Compare commits

...

27 Commits

Author SHA1 Message Date
Bentlybro
957ec038b8 chore: regenerate OpenAPI schema for new LLM endpoints 2026-03-16 15:44:53 +00:00
Bentlybro
9b39a662ee feat(platform): Add LLM registry public read API
Implements public GET endpoints for querying LLM models and providers - Part 3 of 6 in the incremental registry rollout.

**Endpoints:**
- GET /api/llm/models - List all models (filterable by enabled_only)
- GET /api/llm/providers - List providers with their models

**Design:**
- Uses in-memory registry from PR 2 (no DB queries)
- Fast reads from cache populated at startup
- Grouped by provider for easy UI rendering

**Response models:**
- LlmModel - model info with capabilities, costs, creator
- LlmProvider - provider with nested models
- LlmModelsResponse - list + total count
- LlmProvidersResponse - grouped by provider

**Authentication:**
- Requires user auth (requires_user dependency)
- Public within authenticated sessions

**Integration:**
- Registered in rest_api.py at /api prefix
- Tagged with v2 + llm for OpenAPI grouping

**What's NOT included (later PRs):**
- Admin write API (PR 4)
- Block integration (PR 5)
- Redis cache (PR 6)

Lines: ~180 total
Files: 4 (3 new, 1 modified)
Review time: < 10 minutes
2026-03-16 15:44:53 +00:00
Bentlybro
9b93a956b4 style: fix trailing whitespace in registry.py 2026-03-16 15:44:45 +00:00
Bentlybro
b236719bbf fix(startup): handle missing AgentNode table in migrate_llm_models
Tests fail with 'relation "platform.AgentNode" does not exist' because
migrate_llm_models() runs during startup and queries a table that doesn't
exist in fresh test databases.

This is an existing bug in the codebase - the function has no error handling.

Wrap the call in try/except to gracefully handle test environments where
the AgentNode table hasn't been created yet.
2026-03-16 15:12:25 +00:00
Bentlybro
4f286f510f refactor: address CodeRabbit/Majdyz review feedback
- Fix ModelMetadata duplicate type collision by importing from blocks.llm
- Remove _json_to_dict helper, use dict() inline
- Add warning when Provider relation is missing (data corruption indicator)
- Optimize get_default_model_slug with next() (single sort pass)
- Optimize _build_schema_options to use list comprehension
- Move llm_registry import to top-level in rest_api.py
- Ensure max_output_tokens falls back to context_window when null

All critical and quick-win issues addressed.
2026-03-16 14:55:39 +00:00
Bentlybro
b1595d871d fix: address Sentry/CodeRabbit critical and major issues
**CRITICAL FIX - ModelMetadata instantiation:**
- Removed non-existent 'supports_vision' argument
- Added required fields: display_name, provider_name, creator_name, price_tier
- Handle nullable DB fields (Creator, priceTier, maxOutputTokens) safely
- Fallback: creator_name='Unknown' if no Creator, price_tier=1 if invalid

**MAJOR FIX - Preserve pricing unit:**
- Added 'unit' field to RegistryModelCost dataclass
- Prevents RUN vs TOKENS ambiguity in cached costs
- Convert Prisma enum to string when building cost objects

**MAJOR FIX - Deterministic default model:**
- Sort recommended models by display_name before selection
- Prevents non-deterministic results when multiple models are recommended
- Ensures consistent default across refreshes

**STARTUP IMPROVEMENT:**
- Added comment: graceful fallback OK for now (no blocks use registry yet)
- Will be stricter in PR #5 when block integration lands
- Added success log message for registry refresh

Fixes identified by Sentry (critical TypeError) and CodeRabbit review.
2026-03-16 14:55:39 +00:00
Bentlybro
29ab7f2d9c feat(platform): Add LLM registry core - DB layer + in-memory cache
Implements the registry core for dynamic LLM model management:

**DB Layer:**
- Fetch models with provider, costs, and creator relations
- Prisma query with includes for related data
- Convert DB records to typed dataclasses

**In-memory Cache:**
- Global dict for fast model lookups
- Atomic cache refresh with lock protection
- Schema options generation for UI dropdowns

**Public API:**
- get_model(slug) - lookup by slug
- get_all_models() - all models (including disabled)
- get_enabled_models() - enabled models only
- get_schema_options() - UI dropdown data
- get_default_model_slug() - recommended or first enabled
- refresh_llm_registry() - manual refresh trigger

**Integration:**
- Refresh at API startup (before block init)
- Graceful fallback if registry unavailable
- Enables blocks to consume registry data

**Models:**
- RegistryModel - full model with metadata
- RegistryModelCost - pricing configuration
- RegistryModelCreator - model creator info
- ModelMetadata - context window, capabilities

**Next PRs:**
- PR #3: Public read API (GET endpoints)
- PR #4: Admin write API (POST/PATCH/DELETE)
- PR #5: Block integration (update LLM block)
- PR #6: Redis cache (solve thundering herd)

Lines: ~230 (registry.py ~210, __init__.py ~30, model.py from draft)
Files: 4 (3 new, 1 modified)
2026-03-16 14:55:39 +00:00
Bentlybro
784936b323 revert: undo changes to graph.py
Reverting migrate_llm_models modifications per request.
Back to dev baseline for this file.
2026-03-16 14:55:39 +00:00
Bentlybro
f2ae38a1a7 fix(schema): address Majdyz review feedback
- Add FK constraints on LlmModelMigration (sourceModelSlug, targetModelSlug → LlmModel.slug)
- Remove unused @@index([credentialProvider]) on LlmModelCost
- Remove redundant @@index([isReverted]) on LlmModelMigration (covered by composite)
- Add documentation for credentialProvider field explaining its purpose
- Add reverse relation fields to LlmModel (SourceMigrations, TargetMigrations)

Fixes data integrity: typos in migration slugs now caught at DB level.
2026-03-16 14:52:19 +00:00
Bently
2ccfb4e4c1 Merge branch 'dev' into feat/llm-registry-schema 2026-03-10 17:52:01 +00:00
Otto
5641cdd3ca fix(backend): update test patches for validate_url → validate_url_host rename (#12358)
bfb843a renamed `validate_url` to `validate_url_host` in
`agent_browser`, `run_mcp_tool`, and MCP routes, but the corresponding
test files still patched the old name, causing `AttributeError` in CI.

Updates all mock patch targets and assertions across 3 test files:
- `agent_browser_test.py`
- `test_run_mcp_tool.py`  
- `mcp/test_routes.py`

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-10 17:22:11 +00:00
Bentlybro
c65e5c957a fix: isort import order 2026-03-10 16:43:34 +00:00
Bentlybro
54355a691b fix: use execute_raw_with_schema for proper multi-schema support
Per Sentry feedback: db.execute_raw ignores connection string's ?schema=
parameter and defaults to 'public' schema. This breaks in multi-schema setups.

Changes:
- Import execute_raw_with_schema from .db
- Use {schema_prefix} placeholder in query
- Call execute_raw_with_schema instead of db.execute_raw

This matches the pattern used in fix_llm_provider_credentials and other
schema-aware migrations. Works in both CI (public schema) and local
(platform schema from connection string).
2026-03-10 16:25:12 +00:00
Bentlybro
3cafa49c4c fix: remove hardcoded schema prefix from migrate_llm_models query
The raw SQL query in migrate_llm_models() hardcoded platform."AgentNode"
which fails in CI where tables are in 'public' schema (not 'platform').

This code exists in dev but only runs when LLM registry has data. With our
new schema, the migration tries to run at startup and fails in CI.

Changed: UPDATE platform."AgentNode" -> UPDATE "AgentNode"

Matches pattern of all other migrations - let connection string's default
schema handle routing.
2026-03-10 16:19:57 +00:00
Bentlybro
ded002a406 fix: remove CREATE SCHEMA to match CI environment
CI uses schema "public" as default (not "platform"), so creating
a platform schema then tables without prefix puts tables in public
but Prisma looks in platform.

Existing migrations don't create schema - they rely on connection
string's default. Remove CREATE SCHEMA IF NOT EXISTS to match.
2026-03-10 15:57:26 +00:00
Bentlybro
4fdf89c3be fix: remove schema prefix from migration SQL to match existing pattern
CI failing with 'relation "platform.AgentNode" does not exist' because
Prisma generates queries differently when tables are created with
explicit schema prefixes.

Existing AutoGPT migrations use:
  CREATE TABLE "AgentNode" (...)

Not:
  CREATE TABLE "platform"."AgentNode" (...)

The connection string's ?schema=platform handles schema selection,
so explicit prefixes aren't needed and cause compatibility issues.

Changes:
- Remove all "platform". prefixes from:
  * CREATE TYPE statements
  * CREATE TABLE statements
  * CREATE INDEX statements
  * ALTER TABLE statements
  * REFERENCES clauses in foreign keys

Now matches existing migration pattern exactly.
2026-03-10 15:38:41 +00:00
Bentlybro
d816bd739f fix: add partial unique indexes for data integrity
Per CodeRabbit feedback - fix 2 actual bugs:

1. Prevent multiple active migrations per source model
   - Add partial unique index: UNIQUE (sourceModelSlug) WHERE isReverted = false
   - Prevents ambiguous routing when resolving migrations

2. Allow both default and credential-specific costs
   - Remove @@unique([llmModelId, credentialProvider, unit])
   - Add 2 partial unique indexes:
     * UNIQUE (llmModelId, provider, unit) WHERE credentialId IS NULL (defaults)
     * UNIQUE (llmModelId, provider, credentialId, unit) WHERE credentialId IS NOT NULL (overrides)
   - Enables provider-level default costs + per-credential overrides

Schema comments document that these constraints exist in migration SQL.
2026-03-10 15:08:44 +00:00
Otto
bfb843a56e Merge commit from fork
* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-10 15:51:58 +01:00
Bentlybro
6a16376323 fix: remove multiSchema - follow existing AutoGPT pattern
Remove unnecessary multiSchema configuration that broke existing models.

AutoGPT uses connection string's ?schema=platform parameter as default,
not Prisma's multiSchema feature. Existing models (User, AgentGraph, etc.)
have no @@schema() directives and work fine.

Changes:
- Remove schemas = ["platform", "public"] from datasource
- Remove "multiSchema" from previewFeatures
- Remove all @@schema() directives from LLM models and enum

Migration SQL already creates tables in platform schema explicitly
(CREATE TABLE "platform"."LlmProvider" etc.) which is correct.

This matches the existing pattern used throughout the codebase.
2026-03-10 14:49:23 +00:00
Bentlybro
ed7b02ffb1 fix: address CodeRabbit design feedback
Per CodeRabbit review:

1. **Safety: Change capability defaults false → safer for partial seeding**
   - supportsTools: true → false
   - supportsJsonOutput: true → false
   - Prevents partially-seeded rows from being assumed capable

2. **Clarity: Rename supportsParallelTool → supportsParallelToolCalls**
   - More explicit about what the field represents

3. **Performance: Remove redundant indexes**
   - Drop @@index([llmModelId]) - covered by unique constraint
   - Drop @@index([sourceModelSlug]) - covered by composite index
   - Reduces write overhead and storage

4. **Documentation: Acknowledge customCreditCost limitation**
   - It's unit-agnostic (doesn't distinguish RUN vs TOKENS)
   - Noted as TODO for follow-up PR with proper unit-aware override

Schema + migration both updated to match.
2026-03-10 14:27:42 +00:00
Bentlybro
d064198dd1 fix: add @@schema("platform") to LlmCostUnit enum
Sentry caught this - enums also need @@schema directive with multiSchema enabled.
Without it, Prisma looks for enum in public schema but it's created in platform.
2026-03-10 14:23:24 +00:00
Bentlybro
01ad033b2b feat: add database CHECK constraints for data integrity
Per CodeRabbit feedback - enforce numeric domain rules at DB level:

Migration:
- priceTier: CHECK (priceTier BETWEEN 1 AND 3)
- creditCost: CHECK (creditCost >= 0)
- nodeCount: CHECK (nodeCount >= 0)
- customCreditCost: CHECK (customCreditCost IS NULL OR customCreditCost >= 0)

Schema comments:
- Document constraints inline for developer visibility

Prevents invalid data (negative costs, out-of-range tiers) from
entering the database, matching backend/blocks/llm.py contract.
2026-03-10 14:20:07 +00:00
Bentlybro
56bcbda054 fix: use @@schema() instead of @@map() for platform schema + create schema in migration
Critical fixes from PR review:

1. Replace @@map("platform.ModelName") with @@schema("platform")
   - Sentry correctly identified: Prisma was looking for literal table "platform.LlmProvider" with dot
   - Proper syntax: enable multiSchema feature + use @@schema directive

2. Create platform schema in migration
   - CI failed: schema "platform" does not exist
   - Add CREATE SCHEMA IF NOT EXISTS at start of migration

Schema changes:
- datasource: add schemas = ["platform", "public"]
- generator: add "multiSchema" to previewFeatures
- All 5 models: @@map() → @@schema("platform")

Migration changes:
- Add CREATE SCHEMA IF NOT EXISTS "platform" before enum creation

Fixes CI failure and Sentry-identified bug.
2026-03-10 14:15:35 +00:00
Abhimanyu Yadav
684845d946 fix(frontend/builder): handle discriminated unions and improve node layout (#12354)
## Summary
- **Discriminated union support (oneOf)**: Added a new `OneOfField`
component that properly
renders Pydantic discriminated unions. Hides the unusable parent object
handle, auto-populates
the discriminator value, shows a dropdown with variant titles (e.g.,
"Username" / "UserId"), and
filters out the internal discriminator field from the form.
Non-discriminated `oneOf` schemas
  fall back to existing `AnyOfField` behavior.
- **Collapsible object outputs**: Object-type outputs with nested keys
(e.g.,
`PersonLookupResponse.Url`, `PersonLookupResponse.profile`) are now
collapsed by default behind a
caret toggle. Nested keys show short names instead of the full
`Parent.Key` prefix.
- **Node layout cleanup**: Removed excessive bottom margin (`mb-6`) from
`FormRenderer`, hide the
Advanced toggle when no advanced fields exist, and add rounded bottom
corners on OUTPUT-type
  blocks.

<img width="440" height="427" alt="Screenshot 2026-03-10 at 11 31 55 AM"
src="https://github.com/user-attachments/assets/06cc5414-4e02-4371-bdeb-1695e7cb2c97"
/>
<img width="371" height="320" alt="Screenshot 2026-03-10 at 11 36 52 AM"
src="https://github.com/user-attachments/assets/1a55f87a-c602-4f4d-b91b-6e49f810e5d5"
/>

  ## Test plan
- [x] Add a Twitter Get User block — verify "Identifier" shows a
dropdown (Username/UserId) with
no unusable parent handle, discriminator field is hidden, and the block
can run without staying
  INCOMPLETE
- [x] Add any block with object outputs (e.g., PersonLookupResponse) —
verify nested keys are
  collapsed by default and expand on click with short labels
- [x] Verify blocks without advanced fields don't show the Advanced
toggle
- [x] Verify existing `anyOf` schemas (optional types, 3+ variant
unions) still render correctly
  - [x] Check OUTPUT-type blocks have rounded bottom corners

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: eureka928 <meobius123@gmail.com>
2026-03-10 14:13:32 +00:00
Bentlybro
d40efc6056 feat(platform): Add LLM registry database schema
Add Prisma schema and migration for dynamic LLM model registry:

Schema additions:
- LlmProvider: Registry of LLM providers (OpenAI, Anthropic, etc.)
- LlmModel: Individual models with capabilities and metadata
- LlmModelCost: Per-model pricing configuration
- LlmModelCreator: Model creators/trainers (OpenAI, Meta, etc.)
- LlmModelMigration: Track model migrations and reverts
- LlmCostUnit enum: RUN vs TOKENS pricing units

Key features:
- Model-specific capabilities (tools, JSON, reasoning, parallel calls)
- Flexible creator/provider separation (e.g., Meta model via Hugging Face)
- Migration tracking with custom pricing overrides
- Indexes for performance on common queries

Part 1 of incremental LLM registry implementation.
Refs: Draft PR #11699
2026-03-10 13:22:05 +00:00
Bently
6a6b23c2e1 fix(frontend): Remove unused Otto Server Action causing 107K+ errors (#12336)
## Summary

Fixes [OPEN-3025](https://linear.app/autogpt/issue/OPEN-3025) —
**107,571+ Server Action errors** in production

Removes the orphaned `askOtto` Server Action that was left behind after
the Otto chat widget removal in PR #12082.

## Problem

Next.js Server Actions that are never imported are excluded from the
server manifest. Old client bundles still reference the action ID,
causing "not found" errors.

**Sentry impact:**
- **BUILDER-3BN:** 107,571 events
- **BUILDER-729:** 285 events  
- **BUILDER-3QH:** 1,611 events
- **36+ users affected**

## Root Cause

1. **Mar 2025:** Otto widget added to `/build` page with `askOtto`
Server Action
2. **Feb 2026:** Otto widget removed (PR #12082), but `actions.ts` left
behind
3. **Result:** Dead code → not in manifest → errors

## Evidence

```bash
# Zero imports across frontend:
grep -r "askOtto" src/ --exclude="actions.ts"
# → No results

# Server manifest missing the action:
cat .next/server/server-reference-manifest.json
# → Only includes login/supabase actions, NOT build/actions
```

## Changes

-  Delete
`autogpt_platform/frontend/src/app/(platform)/build/actions.ts`

## Testing

1. Verify no imports of `askOtto` in codebase 
2. Check Sentry for error drop after deploy
3. Monitor for new "Server Action not found" errors

## Checklist

- [x] Dead code confirmed (zero imports)
- [x] Sentry issues documented
- [x] Clear commit message with context
2026-03-10 09:03:38 +00:00
Dream
d0a1d72e8a fix(frontend/builder): batch undo history for cascading operations (#12344)
## Summary

Fixes undo in the Builder not working correctly when deleting nodes.
When a node is deleted, React Flow fires `onNodesChange` (node removal)
and `onEdgesChange` (cascading edge cleanup) as separate callbacks —
each independently pushing to the undo history stack. This creates
intermediate states that break undo:

- Single undo restores a partial state (e.g. edges pointing to a deleted
node)
- Multiple undos required to fully restore the graph
- Redo also produces inconsistent states

Resolves #10999

### Changes 🏗️

- **`historyStore.ts`** — Added microtask-based batching to
`pushState()`. Multiple calls within the same synchronous execution
(same event loop tick) are coalesced into a single history entry,
keeping only the first pre-change snapshot. Uses `queueMicrotask` so all
cascading store updates from a single user action settle before the
history entry is committed.
- Reset `pendingState` in `initializeHistory()` and `clear()` to prevent
stale batched state from leaking across graph loads or navigation.

**Side benefit:** Copy/paste operations that add multiple nodes and
edges now also produce a single history entry instead of one per
node/edge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Place 3 blocks A, B, C and connect A→B→C
  - [x] Delete block C (removes node + cascading edge B→C)
  - [x] Delete connection A→B
  - [x] Undo — connection A→B restored (single undo, not multiple)
  - [x] Undo — block C and connection B→C restored
  - [x] Redo — block C removed again with its connections
- [x] Copy/paste multiple connected blocks — single undo reverts entire
paste

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-10 04:55:07 +00:00
33 changed files with 10241 additions and 2108 deletions

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -37,8 +37,10 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
@@ -117,11 +119,30 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -348,6 +369,11 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.server.v2.llm.router,
tags=["v2", "llm"],
prefix="/api",
)
app.mount("/external-api", external_api)

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -34,8 +34,11 @@ from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -805,6 +808,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url (backend.util.request)
# SSRF protection via shared validate_url_host (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url for SSRF protection.
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
We mock validate_url itself (not the low-level socket) so these tests
We mock validate_url_host itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
mock_validate.assert_called_once_with("https://example.com")
# ---------------------------------------------------------------------------

View File

@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url
from backend.util.request import HTTPClientError, validate_url_host
from .base import BaseTool
from .models import (
@@ -144,7 +144,7 @@ class RunMCPToolTool(BaseTool):
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url(server_url, trusted_origins=[])
await validate_url_host(server_url)
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:

View File

@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
):
with patch(
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
mock_tools = _make_tool_list("fetch", "search")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
mock_tools = _make_tool_list("push_notification")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
text_result = "# Example Domain\nThis domain is for examples."
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
mock_creds.access_token = SecretStr("stale-token")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
url_with_slash = "https://mcp.example.com/mcp/"
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",

View File

@@ -0,0 +1,31 @@
"""LLM Registry - Dynamic model management system."""
from .model import ModelMetadata
from .registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_all_models,
get_default_model_slug,
get_enabled_models,
get_model,
get_schema_options,
refresh_llm_registry,
)
__all__ = [
# Models
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Functions
"refresh_llm_registry",
"get_model",
"get_all_models",
"get_enabled_models",
"get_schema_options",
"get_default_model_slug",
"get_all_model_slugs_for_validation",
]

View File

@@ -0,0 +1,9 @@
"""Type definitions for LLM model metadata.
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
__all__ = ["ModelMetadata"]

View File

@@ -0,0 +1,240 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
# In-memory cache (will be replaced with Redis in PR #6)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
Fetches all models with their costs, providers, and creators,
then updates the in-memory cache.
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.info(f"Fetched {len(records)} LLM models from database")
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
# Create model instance
model = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
new_models[record.slug] = model
# Atomic swap
global _dynamic_models, _schema_options
_dynamic_models = new_models
_schema_options = _build_schema_options()
logger.info(
f"LLM registry refreshed: {len(_dynamic_models)} models, "
f"{len(_schema_options)} schema options"
)
except Exception as e:
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
raise
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
return [
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
for model in sorted(
_dynamic_models.values(), key=lambda m: m.display_name.lower()
)
if model.is_enabled
]
def get_model(slug: str) -> RegistryModel | None:
"""Get a model by slug from the registry."""
return _dynamic_models.get(slug)
def get_all_models() -> list[RegistryModel]:
"""Get all models from the registry (including disabled)."""
return list(_dynamic_models.values())
def get_enabled_models() -> list[RegistryModel]:
"""Get only enabled models from the registry."""
return [model for model in _dynamic_models.values() if model.is_enabled]
def get_schema_options() -> list[dict[str, str]]:
"""Get schema options for model selection dropdown (enabled models only)."""
return _schema_options
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
return next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]:
"""
Get all model slugs for validation (enables migrate_llm_models to work).
Returns slugs for enabled models only.
"""
return [model.slug for model in _dynamic_models.values() if model.is_enabled]

View File

@@ -0,0 +1,5 @@
"""LLM registry public API."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,67 @@
"""Pydantic models for LLM registry public API."""
from __future__ import annotations
from typing import Any
import pydantic
class LlmModelCost(pydantic.BaseModel):
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int = pydantic.Field(ge=0)
credential_provider: str
credential_id: str | None = None
credential_type: str | None = None
currency: str | None = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model."""
id: str
name: str
display_name: str
description: str | None = None
website_url: str | None = None
logo_url: str | None = None
class LlmModel(pydantic.BaseModel):
"""Public-facing LLM model information."""
slug: str
display_name: str
description: str | None = None
provider_name: str
creator: LlmModelCreator | None = None
context_window: int
max_output_tokens: int | None = None
price_tier: int # 1=cheapest, 2=medium, 3=expensive
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
"""Provider with its enabled models."""
name: str
display_name: str
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmModelsResponse(pydantic.BaseModel):
"""Response for GET /llm/models."""
models: list[LlmModel]
total: int
class LlmProvidersResponse(pydantic.BaseModel):
"""Response for GET /llm/providers."""
providers: list[LlmProvider]

View File

@@ -0,0 +1,141 @@
"""Public read-only API for LLM registry."""
import autogpt_libs.auth
import fastapi
from backend.data.llm_registry import (
RegistryModelCreator,
get_all_models,
get_enabled_models,
)
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
def _map_creator(
creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
"""Convert registry creator to API model."""
if not creator:
return None
return llm_model.LlmModelCreator(
id=creator.id,
name=creator.name,
display_name=creator.display_name,
description=creator.description,
website_url=creator.website_url,
logo_url=creator.logo_url,
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
enabled_only: bool = fastapi.Query(
default=True, description="Only return enabled models"
),
):
"""
List all LLM models available to users.
Returns models from the in-memory registry cache.
Use enabled_only=true to filter to only enabled models (default).
"""
# Get models from in-memory registry
registry_models = get_enabled_models() if enabled_only else get_all_models()
# Map to API response models
models = [
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in registry_models
]
return llm_model.LlmModelsResponse(models=models, total=len(models))
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""
List all LLM providers with their enabled models.
Groups enabled models by provider from the in-memory registry.
"""
# Get all enabled models and group by provider
registry_models = get_enabled_models()
# Group models by provider
provider_map: dict[str, list] = {}
for model in registry_models:
provider_key = model.metadata.provider
if provider_key not in provider_map:
provider_map[provider_key] = []
provider_map[provider_key].append(model)
# Build provider responses
providers = []
for provider_key, models in sorted(provider_map.items()):
# Use the first model's provider display name
display_name = models[0].provider_display_name if models else provider_key
providers.append(
llm_model.LlmProvider(
name=provider_key,
display_name=display_name,
models=[
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in sorted(models, key=lambda m: m.display_name)
],
)
)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]:
return ip_addresses
async def validate_url(
url: str, trusted_origins: list[str]
async def validate_url_host(
url: str, trusted_hostnames: Optional[list[str]] = None
) -> tuple[URL, bool, list[str]]:
"""
Validates the URL to prevent SSRF attacks by ensuring it does not point
to a private, link-local, or otherwise blocked IP address — unless
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
point to a private, link-local, or otherwise blocked IP address — unless
the hostname is explicitly trusted.
Hosts in `trusted_hostnames` are permitted without checks.
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
Params:
url: A hostname, netloc, or URL to validate.
If no scheme is included, `http://` is assumed.
trusted_hostnames: A list of hostnames that don't require validation.
Raises:
ValueError:
- if the URL has a disallowed URL scheme
- if the URL/host string can't be parsed
- if the hostname contains invalid or unsupported (non-ASCII) characters
- if the host resolves to a blocked IP
Returns:
str: The validated, canonicalized, parsed URL
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
1. The validated, canonicalized, parsed host/URL,
with hostname ASCII-safe encoding
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
3. List of resolved IP addresses for the host; empty if the host is trusted.
"""
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
f"{', '.join(ALLOWED_SCHEMES)}"
)
# Validate and IDNA encode hostname
if not parsed.hostname:
raise ValueError("Invalid URL: No hostname found.")
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError("Invalid hostname with unsupported characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError("Hostname contains invalid characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check if hostname is trusted
is_trusted = ascii_hostname in trusted_origins
# If not trusted, validate IP addresses
ip_addresses: list[str] = []
if not is_trusted:
# Resolve all IP addresses for the hostname
ip_addresses = await _resolve_host(ascii_hostname)
# Block any IP address that belongs to a blocked range
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {ascii_hostname} is not allowed."
)
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
netloc = ascii_hostname
if parsed.port:
netloc = f"{ascii_hostname}:{parsed.port}"
return (
URL(
parsed.scheme,
netloc,
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
),
is_trusted,
ip_addresses,
# Re-create parsed URL object with IDNA-encoded hostname
parsed = URL(
parsed.scheme,
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
)
is_trusted = trusted_hostnames and any(
matches_allowed_host(parsed, allowed)
for allowed in (
# Normalize + parse allowlist entries the same way for consistent matching
parse_url(w)
for w in trusted_hostnames
)
)
if is_trusted:
return parsed, True, []
# If not allowlisted, go ahead with host resolution and IP target check
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
def matches_allowed_host(url: URL, allowed: URL) -> bool:
if url.hostname != allowed.hostname:
return False
# Allow any port if not explicitly specified in the allowlist
if allowed.port is None:
return True
return url.port == allowed.port
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
"""
Resolves hostname to IPs and raises ValueError if any resolve to
a blocked network. Returns the list of resolved IP addresses.
"""
ip_addresses = await _resolve_host(hostname)
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {hostname} is not allowed."
)
return ip_addresses
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
@@ -352,7 +382,7 @@ class Requests:
):
self.trusted_origins = []
for url in trusted_origins or []:
hostname = urlparse(url).hostname
hostname = parse_url(url).netloc # {host}[:{port}]
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
self.trusted_origins.append(hostname)
@@ -450,7 +480,7 @@ class Requests:
data = form
# Validate URL and get trust status
parsed_url, is_trusted, ip_addresses = await validate_url(
parsed_url, is_trusted, ip_addresses = await validate_url_host(
url, self.trusted_origins
)
@@ -503,7 +533,6 @@ class Requests:
json=json,
**kwargs,
) as response:
if self.raise_for_status:
try:
response.raise_for_status()

View File

@@ -1,7 +1,7 @@
import pytest
from aiohttp import web
from backend.util.request import pin_url, validate_url
from backend.util.request import pin_url, validate_url_host
@pytest.mark.parametrize(
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
):
if should_raise:
with pytest.raises(ValueError):
await validate_url(raw_url, trusted_origins)
await validate_url_host(raw_url, trusted_origins)
else:
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
assert validated_url.geturl() == expected_value
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
if expect_error:
# If any IP is blocked, we expect a ValueError
with pytest.raises(ValueError):
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pin_url(url, ip_addresses)
else:
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pinned_url = pin_url(url, ip_addresses).geturl()
# The pinned_url should contain the first valid IP
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")

View File

@@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",

View File

@@ -0,0 +1,148 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"creatorId" TEXT,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"priceTier" INTEGER NOT NULL DEFAULT 1,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"isRecommended" BOOLEAN NOT NULL DEFAULT false,
"supportsTools" BOOLEAN NOT NULL DEFAULT false,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
"supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
"capabilities" JSONB NOT NULL DEFAULT '{}',
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;
-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");
-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddCheckConstraints (enforce data integrity)
ALTER TABLE "LlmModel"
ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);
ALTER TABLE "LlmModelCost"
ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);
ALTER TABLE "LlmModelMigration"
ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);

View File

@@ -1304,3 +1304,164 @@ model OAuthRefreshToken {
@@index([userId, applicationId])
@@index([expiresAt]) // For cleanup
}
// ============================================================================
// LLM Registry Models
// ============================================================================
enum LlmCostUnit {
RUN
TOKENS
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
// Model-specific capabilities
// These vary per model even within the same provider (e.g., Hugging Face)
// Default to false for safety - partially-seeded rows should not be assumed capable
supportsTools Boolean @default(false)
supportsJsonOutput Boolean @default(false)
supportsReasoning Boolean @default(false)
supportsParallelToolCalls Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
@@index([providerId, isEnabled])
@@index([creatorId])
// Note: slug already has @unique which creates an implicit index
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int // DB constraint: >= 0
// Provider identifier (e.g., "openai", "anthropic", "openrouter")
// Used to determine which credential system provides the API key.
// Allows different pricing for:
// - Default provider costs (WHERE credentialId IS NULL)
// - User's own API key costs (WHERE credentialId IS NOT NULL)
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
// Note: Unique constraints are implemented as partial indexes in migration SQL:
// - One for default costs (WHERE credentialId IS NULL)
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
// This allows both provider-level defaults and credential-specific overrides
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// FK constraints ensure slugs reference valid models
SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated (DB constraint: >= 0)
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
// LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
// For token-priced models, this may be ambiguous. Consider migrating to a relation
// with LlmModelCost or a dedicated override model in a follow-up PR.
customCreditCost Int? // DB constraint: >= 0 when not null
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
// Note: Partial unique index in migration SQL prevents multiple active migrations per source:
// UNIQUE (sourceModelSlug) WHERE isReverted = false
@@index([targetModelSlug])
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}

View File

@@ -1,36 +0,0 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api/client";
import { OttoQuery, OttoResponse } from "@/lib/autogpt-server-api/types";
const api = new BackendAPI();
export async function askOtto(
query: string,
conversationHistory: { query: string; response: string }[],
includeGraphData: boolean,
graphId?: string,
): Promise<OttoResponse> {
const messageId = `${Date.now()}-web`;
const ottoQuery: OttoQuery = {
query,
conversation_history: conversationHistory,
message_id: messageId,
include_graph_data: includeGraphData,
graph_id: graphId,
};
try {
const response = await api.askOtto(ottoQuery);
return response;
} catch (error) {
console.error("Error in askOtto server action:", error);
return {
answer: error instanceof Error ? error.message : "Unknown error occurred",
documents: [],
success: false,
error: true,
};
}
}

View File

@@ -23,6 +23,12 @@ import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
import { SubAgentUpdateFeature } from "./components/SubAgentUpdate/SubAgentUpdateFeature";
import { useCustomNode } from "./useCustomNode";
function hasAdvancedFields(schema: RJSFSchema): boolean {
const properties = schema?.properties;
if (!properties) return false;
return Object.values(properties).some((prop: any) => prop.advanced === true);
}
export type CustomNodeData = {
hardcodedValues: {
[key: string]: any;
@@ -108,7 +114,11 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
)}
showHandles={showHandles}
/>
<NodeAdvancedToggle nodeId={nodeId} />
<NodeAdvancedToggle
nodeId={nodeId}
isLastSection={data.uiType === BlockUIType.OUTPUT}
hasAdvancedFields={hasAdvancedFields(inputSchema)}
/>
{data.uiType != BlockUIType.OUTPUT && (
<OutputHandler
uiType={data.uiType}

View File

@@ -2,18 +2,33 @@ import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
type Props = {
nodeId: string;
isLastSection?: boolean;
hasAdvancedFields?: boolean;
};
export function NodeAdvancedToggle({ nodeId }: Props) {
export function NodeAdvancedToggle({
nodeId,
isLastSection,
hasAdvancedFields = true,
}: Props) {
const showAdvanced = useNodeStore(
(state) => state.nodeAdvancedStates[nodeId] || false,
);
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
if (!hasAdvancedFields) return null;
return (
<div className="flex items-center justify-start gap-2 bg-white px-5 pb-3.5">
<div
className={cn(
"flex items-center justify-start gap-2 bg-white px-5 pb-3.5",
isLastSection && "rounded-b-xlarge",
)}
>
<Button
variant="ghost"
className="h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"

View File

@@ -1,6 +1,6 @@
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { CaretDownIcon, InfoIcon } from "@phosphor-icons/react";
import { CaretDownIcon, CaretRightIcon, InfoIcon } from "@phosphor-icons/react";
import { RJSFSchema } from "@rjsf/utils";
import { useState } from "react";
@@ -30,13 +30,41 @@ export const OutputHandler = ({
const properties = outputSchema?.properties || {};
const [isOutputVisible, setIsOutputVisible] = useState(true);
const brokenOutputs = useBrokenOutputs(nodeId);
const [expandedObjects, setExpandedObjects] = useState<
Record<string, boolean>
>({});
const showHandles = uiType !== BlockUIType.OUTPUT;
function toggleObjectExpanded(key: string) {
setExpandedObjects((prev) => ({ ...prev, [key]: !prev[key] }));
}
function hasConnectedOrBrokenDescendant(
schema: RJSFSchema,
keyPrefix: string,
): boolean {
if (!schema) return false;
return Object.entries(schema).some(
([key, fieldSchema]: [string, RJSFSchema]) => {
const fullKey = keyPrefix ? `${keyPrefix}_#_${key}` : key;
if (isOutputConnected(nodeId, fullKey) || brokenOutputs.has(fullKey))
return true;
if (fieldSchema?.properties)
return hasConnectedOrBrokenDescendant(
fieldSchema.properties,
fullKey,
);
return false;
},
);
}
const renderOutputHandles = (
schema: RJSFSchema,
keyPrefix: string = "",
titlePrefix: string = "",
connectedOnly: boolean = false,
): React.ReactNode[] => {
return Object.entries(schema).map(
([key, fieldSchema]: [string, RJSFSchema]) => {
@@ -44,10 +72,23 @@ export const OutputHandler = ({
const fieldTitle = titlePrefix + (fieldSchema?.title || key);
const isConnected = isOutputConnected(nodeId, fullKey);
const shouldShow = isConnected || isOutputVisible;
const isBroken = brokenOutputs.has(fullKey);
const hasNestedProperties = !!fieldSchema?.properties;
const selfIsRelevant = isConnected || isBroken;
const descendantIsRelevant =
hasNestedProperties &&
hasConnectedOrBrokenDescendant(fieldSchema.properties!, fullKey);
const shouldShow = connectedOnly
? selfIsRelevant || descendantIsRelevant
: isOutputVisible || selfIsRelevant || descendantIsRelevant;
const { displayType, colorClass, hexColor } =
getTypeDisplayInfo(fieldSchema);
const isBroken = brokenOutputs.has(fullKey);
const isExpanded = expandedObjects[fullKey] ?? false;
// User expanded → show all children; auto-expanded → filter to connected only
const shouldRenderChildren = isExpanded || descendantIsRelevant;
return shouldShow ? (
<div
@@ -56,6 +97,19 @@ export const OutputHandler = ({
data-tutorial-id={`output-handler-${nodeId}-${fieldTitle}`}
>
<div className="relative flex items-center gap-2">
{hasNestedProperties && (
<button
onClick={() => toggleObjectExpanded(fullKey)}
className="flex items-center text-slate-500 hover:text-slate-700"
aria-label={isExpanded ? "Collapse" : "Expand"}
>
{isExpanded ? (
<CaretDownIcon size={12} weight="bold" />
) : (
<CaretRightIcon size={12} weight="bold" />
)}
</button>
)}
{fieldSchema?.description && (
<TooltipProvider>
<Tooltip>
@@ -102,12 +156,14 @@ export const OutputHandler = ({
)}
</div>
{/* Recursively render nested properties */}
{fieldSchema?.properties &&
{/* Nested properties */}
{hasNestedProperties &&
shouldRenderChildren &&
renderOutputHandles(
fieldSchema.properties,
fieldSchema.properties!,
fullKey,
`${fieldTitle}.`,
"",
!isExpanded,
)}
</div>
) : null;
@@ -136,7 +192,7 @@ export const OutputHandler = ({
</Button>
<div className="flex flex-col items-end gap-2">
{renderOutputHandles(properties)}
{renderOutputHandles(properties, "", "", !isOutputVisible)}
</div>
</div>
);

View File

@@ -25,34 +25,60 @@ type HistoryStore = {
const MAX_HISTORY = 50;
// Microtask batching state — kept outside the store to avoid triggering
// re-renders. When multiple pushState calls happen in the same synchronous
// execution (e.g. node deletion cascading to edge cleanup), only the first
// (pre-change) state is kept and committed as a single history entry.
let pendingState: HistoryState | null = null;
let batchScheduled = false;
export const useHistoryStore = create<HistoryStore>((set, get) => ({
past: [{ nodes: [], edges: [] }],
future: [],
pushState: (state: HistoryState) => {
const { past } = get();
const lastState = past[past.length - 1];
if (lastState && isEqual(lastState, state)) {
return;
// Keep only the first state within a microtask batch — it represents
// the true pre-change snapshot before any cascading mutations.
if (!pendingState) {
pendingState = state;
}
const actualCurrentState = {
nodes: useNodeStore.getState().nodes,
edges: useEdgeStore.getState().edges,
};
if (!batchScheduled) {
batchScheduled = true;
queueMicrotask(() => {
const stateToCommit = pendingState;
pendingState = null;
batchScheduled = false;
if (isEqual(state, actualCurrentState)) {
return;
if (!stateToCommit) return;
const { past } = get();
const lastState = past[past.length - 1];
if (lastState && isEqual(lastState, stateToCommit)) {
return;
}
const actualCurrentState = {
nodes: useNodeStore.getState().nodes,
edges: useEdgeStore.getState().edges,
};
if (isEqual(stateToCommit, actualCurrentState)) {
return;
}
set((prev) => ({
past: [...prev.past.slice(-MAX_HISTORY + 1), stateToCommit],
future: [],
}));
});
}
set((prev) => ({
past: [...prev.past.slice(-MAX_HISTORY + 1), state],
future: [],
}));
},
initializeHistory: () => {
pendingState = null;
const currentNodes = useNodeStore.getState().nodes;
const currentEdges = useEdgeStore.getState().edges;
@@ -122,5 +148,8 @@ export const useHistoryStore = create<HistoryStore>((set, get) => ({
},
canRedo: () => get().future.length > 0,
clear: () => set({ past: [{ nodes: [], edges: [] }], future: [] }),
clear: () => {
pendingState = null;
set({ past: [{ nodes: [], edges: [] }], future: [] });
},
}));

File diff suppressed because it is too large Load Diff

View File

@@ -34,10 +34,7 @@ export function FormRenderer({
}, [preprocessedSchema, uiSchema]);
return (
<div
className={cn("mb-6 mt-4", className)}
data-tutorial-id="input-handles"
>
<div className={cn("mt-4", className)} data-tutorial-id="input-handles">
<Form
formContext={formContext}
idPrefix="agpt"

View File

@@ -63,7 +63,6 @@ export const useAnyOfField = (props: FieldProps) => {
);
const handlePrefix = cleanUpHandleId(field_id);
console.log("handlePrefix", handlePrefix);
useEdgeStore
.getState()
.removeEdgesByHandlePrefix(registry.formContext.nodeId, handlePrefix);

View File

@@ -4,6 +4,7 @@ import {
TemplatesType,
} from "@rjsf/utils";
import { AnyOfField } from "./anyof/AnyOfField";
import { OneOfField } from "./oneof/OneOfField";
import {
ArrayFieldItemTemplate,
ArrayFieldTemplate,
@@ -32,6 +33,7 @@ const NoButton = () => null;
export function generateBaseFields(): RegistryFieldsType {
return {
AnyOfField,
OneOfField,
ArraySchemaField,
};
}

View File

@@ -0,0 +1,243 @@
import {
descriptionId,
FieldProps,
getTemplate,
getUiOptions,
getWidget,
} from "@rjsf/utils";
import { useEffect, useRef, useState } from "react";
import { AnyOfField } from "../anyof/AnyOfField";
import { cleanUpHandleId, getHandleId, updateUiOption } from "../../helpers";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { ANY_OF_FLAG } from "../../constants";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
function getDiscriminatorPropName(schema: any): string | undefined {
if (!schema?.discriminator) return undefined;
if (typeof schema.discriminator === "string") return schema.discriminator;
return schema.discriminator.propertyName;
}
export function OneOfField(props: FieldProps) {
const { schema } = props;
const discriminatorProp = getDiscriminatorPropName(schema);
if (!discriminatorProp) {
return <AnyOfField {...props} />;
}
return (
<DiscriminatedUnionField {...props} discriminatorProp={discriminatorProp} />
);
}
interface DiscriminatedUnionFieldProps extends FieldProps {
discriminatorProp: string;
}
function DiscriminatedUnionField({
discriminatorProp,
...props
}: DiscriminatedUnionFieldProps) {
const { schema, registry, formData, onChange, name } = props;
const { fields, schemaUtils, formContext } = registry;
const { SchemaField } = fields;
const { nodeId } = formContext;
const field_id = props.fieldPathId.$id;
// Resolve variant schemas from $refs
const variants = useRef(
(schema.oneOf || []).map((opt: any) =>
schemaUtils.retrieveSchema(opt, formData),
),
);
// Build dropdown options from variant titles and discriminator const values
const enumOptions = variants.current.map((variant: any, index: number) => {
const discValue = (variant.properties?.[discriminatorProp] as any)?.const;
return {
value: index,
label: variant.title || discValue || `Option ${index + 1}`,
discriminatorValue: discValue,
};
});
// Determine initial selected index from formData
function getInitialIndex() {
const currentDisc = formData?.[discriminatorProp];
if (currentDisc) {
const idx = enumOptions.findIndex(
(o) => o.discriminatorValue === currentDisc,
);
if (idx >= 0) return idx;
}
return 0;
}
const [selectedIndex, setSelectedIndex] = useState(getInitialIndex);
// Generate handleId for sub-fields (same convention as AnyOfField)
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
const handleId = getHandleId({
uiOptions,
id: field_id + ANY_OF_FLAG,
schema,
});
const childUiSchema = updateUiOption(props.uiSchema, {
handleId,
label: false,
fromAnyOf: true,
});
// Get selected variant schema with discriminator property filtered out
// and sub-fields inheriting the parent's advanced value
const selectedVariant = variants.current[selectedIndex];
const parentAdvanced = (schema as any).advanced;
function getFilteredSchema() {
if (!selectedVariant?.properties) return selectedVariant;
const filteredProperties: Record<string, any> = {};
for (const [key, value] of Object.entries(selectedVariant.properties)) {
if (key === discriminatorProp) continue;
filteredProperties[key] =
parentAdvanced !== undefined
? { ...(value as any), advanced: parentAdvanced }
: value;
}
return {
...selectedVariant,
properties: filteredProperties,
required: (selectedVariant.required || []).filter(
(r: string) => r !== discriminatorProp,
),
};
}
const filteredSchema = getFilteredSchema();
// Handle variant change
function handleVariantChange(option?: string) {
const newIndex = option !== undefined ? parseInt(option, 10) : -1;
if (newIndex === selectedIndex || newIndex < 0) return;
const newVariant = variants.current[newIndex];
const oldVariant = variants.current[selectedIndex];
const discValue = (newVariant.properties?.[discriminatorProp] as any)
?.const;
// Clean edges for this field
const handlePrefix = cleanUpHandleId(field_id);
useEdgeStore.getState().removeEdgesByHandlePrefix(nodeId, handlePrefix);
// Sanitize current data against old→new schema to preserve shared fields
let newFormData = schemaUtils.sanitizeDataForNewSchema(
newVariant,
oldVariant,
formData,
);
// Fill in defaults for the new variant
newFormData = schemaUtils.getDefaultFormState(
newVariant,
newFormData,
"excludeObjectChildren",
) as any;
newFormData = { ...newFormData, [discriminatorProp]: discValue };
setSelectedIndex(newIndex);
onChange(newFormData, props.fieldPathId.path, undefined, field_id);
}
// Sync selectedIndex when formData discriminator changes externally
// (e.g. undo/redo, loading saved state)
const currentDiscValue = formData?.[discriminatorProp];
useEffect(() => {
const idx = currentDiscValue
? enumOptions.findIndex((o) => o.discriminatorValue === currentDiscValue)
: -1;
if (idx >= 0) {
if (idx !== selectedIndex) setSelectedIndex(idx);
} else if (enumOptions.length > 0 && selectedIndex !== 0) {
// Unknown or cleared discriminator — full reset via same cleanup path
handleVariantChange("0");
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentDiscValue]);
// Auto-set discriminator on initial render if missing
useEffect(() => {
const discValue = enumOptions[selectedIndex]?.discriminatorValue;
if (discValue && formData?.[discriminatorProp] !== discValue) {
onChange(
{ ...formData, [discriminatorProp]: discValue },
props.fieldPathId.path,
undefined,
field_id,
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
const selector = (
<Widget
id={field_id}
name={`${name}__oneof_select`}
schema={{ type: "number", default: 0 }}
onChange={handleVariantChange}
onBlur={props.onBlur}
onFocus={props.onFocus}
disabled={props.disabled || enumOptions.length === 0}
multiple={false}
value={selectedIndex}
options={{ enumOptions }}
registry={registry}
placeholder={props.placeholder}
autocomplete={props.autocomplete}
className={cn("-ml-1 h-[22px] w-fit gap-1 px-1 pl-2 text-xs font-medium")}
autofocus={props.autofocus}
label=""
hideLabel={true}
readonly={props.readonly}
/>
);
const DescriptionFieldTemplate = getTemplate(
"DescriptionFieldTemplate",
registry,
uiOptions,
);
const description_id = descriptionId(props.fieldPathId ?? "");
return (
<div>
<div className="flex items-center gap-2">
<Text variant="body" className="line-clamp-1">
{schema.title || name}
</Text>
<Text variant="small" className="mr-1 text-red-500">
{props.required ? "*" : null}
</Text>
{selector}
<DescriptionFieldTemplate
id={description_id}
description={schema.description || ""}
schema={schema}
registry={registry}
/>
</div>
{filteredSchema && filteredSchema.type !== "null" && (
<SchemaField
{...props}
schema={filteredSchema}
uiSchema={childUiSchema}
/>
)}
</div>
);
}

View File

@@ -6,7 +6,11 @@ import {
titleId,
} from "@rjsf/utils";
import { isAnyOfChild, isAnyOfSchema } from "../../utils/schema-utils";
import {
isAnyOfChild,
isAnyOfSchema,
isOneOfSchema,
} from "../../utils/schema-utils";
import {
cleanUpHandleId,
getHandleId,
@@ -82,12 +86,13 @@ export default function FieldTemplate(props: FieldTemplateProps) {
const shouldDisplayLabel =
displayLabel ||
(schema.type === "boolean" && !isAnyOfChild(uiSchema as any));
const shouldShowTitleSection = !isAnyOfSchema(schema) && !additional;
const isUnionSchema = isAnyOfSchema(schema) || isOneOfSchema(schema);
const shouldShowTitleSection = !isUnionSchema && !additional;
const shouldShowChildren =
schema.type === "object" ||
schema.type === "array" ||
isAnyOfSchema(schema) ||
isUnionSchema ||
!isHandleConnected;
const isAdvancedField = (schema as any).advanced === true;
@@ -95,8 +100,7 @@ export default function FieldTemplate(props: FieldTemplateProps) {
return null;
}
const marginBottom =
isPartOfAnyOf({ uiOptions }) || isAnyOfSchema(schema) ? 0 : 16;
const marginBottom = isPartOfAnyOf({ uiOptions }) || isUnionSchema ? 0 : 16;
return (
<WrapIfAdditionalTemplate

View File

@@ -7,7 +7,7 @@ import {
import { Text } from "@/components/atoms/Text/Text";
import { getTypeDisplayInfo } from "@/app/(platform)/build/components/FlowEditor/nodes/helpers";
import { isAnyOfSchema } from "../../utils/schema-utils";
import { isAnyOfSchema, isOneOfSchema } from "../../utils/schema-utils";
import { cn } from "@/lib/utils";
import { cleanUpHandleId, isArrayItem } from "../../helpers";
import { InputNodeHandle } from "@/app/(platform)/build/components/FlowEditor/handlers/NodeHandle";
@@ -18,7 +18,7 @@ export default function TitleField(props: TitleFieldProps) {
const { nodeId, showHandles } = registry.formContext;
const uiOptions = getUiOptions(uiSchema);
const isAnyOf = isAnyOfSchema(schema);
const isAnyOf = isAnyOfSchema(schema) || isOneOfSchema(schema);
const { displayType, colorClass } = getTypeDisplayInfo(schema);
const description_id = descriptionId(id);

View File

@@ -8,6 +8,14 @@ export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean {
);
}
export function isOneOfSchema(schema: RJSFSchema | undefined): boolean {
return (
Array.isArray(schema?.oneOf) &&
schema!.oneOf.length > 0 &&
schema?.enum === undefined
);
}
export const isAnyOfChild = (
uiSchema: UiSchema<any, RJSFSchema, any> | undefined,
): boolean => {