Compare commits

..

96 Commits

Author SHA1 Message Date
Zamil Majdy
60c6c14211 fix(chat/tools): remove hard output truncation from sandbox and web_fetch
The SDK already handles oversized tool output by writing to tool-results
files and reading back via MCP. Our 50K char truncation was cutting off
output before the agent could see it — the SDK's mechanism is the proper
way to handle large results.
2026-02-13 16:56:05 +04:00
Zamil Majdy
a86878ae6f fix(chat/sdk): align read_transcript_file min lines with validate_transcript
Both now accept >=2 lines (a valid 1-turn transcript is user+assistant).
2026-02-13 16:39:32 +04:00
Zamil Majdy
80e413b969 fix(chat/sdk): address PR #12103 review comments
- Remove duplicate security check in _cleanup_sdk_tool_results (copy-paste)
- Don't delete transcript on transient errors — only the current
  turn failed, the transcript is still valid for future resume
- Add post-construction realpath check in write_transcript_to_tempfile
  to satisfy CodeQL taint analysis
2026-02-13 16:38:03 +04:00
Zamil Majdy
52c8a25531 fix(chat/sdk): fix transcript validation and type captured_transcript properly
- Replace dict[str,str] with CapturedTranscript dataclass for type safety
- Fix validate_transcript requiring >=3 lines — after stripping metadata,
  a valid 1-turn conversation is just user+assistant (2 lines)
- Apply CodeQL autofix: internalize max_len in _sanitize_id, add fallback
2026-02-13 16:32:06 +04:00
Zamil Majdy
d0f0c32e70 fix(chat/sdk): validate cwd against sandbox prefix to fix CodeQL alert
CodeQL traces session_id → cwd → os.makedirs/open as uncontrolled path.
Add realpath + startswith check against /tmp/copilot- prefix directly in
write_transcript_to_tempfile so CodeQL recognizes the sanitization.

Also resolve the prefix with realpath for macOS where /tmp → /private/tmp.
2026-02-13 15:49:30 +04:00
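A minimal sketch of the prefix check described in the commit above; only the /tmp/copilot- prefix comes from the commit, while the helper name and error handling are illustrative:

```python
import os

# Resolve the prefix once so /tmp -> /private/tmp on macOS is handled.
_SANDBOX_PREFIX = os.path.realpath("/tmp/copilot-")

def _validate_sdk_cwd(cwd: str) -> str:
    """Reject any cwd that does not resolve to a path under the sandbox prefix."""
    resolved = os.path.realpath(cwd)
    if not resolved.startswith(_SANDBOX_PREFIX):
        raise ValueError(f"cwd {cwd!r} escapes sandbox prefix {_SANDBOX_PREFIX!r}")
    return resolved
```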
Zamil Majdy
8dfd0a77a0 fix(chat/sdk): sanitize IDs in transcript paths to fix CodeQL alert
Add _sanitize_id() that strips non-hex characters from session/user IDs
before using them in file paths. Also add realpath containment check in
write_transcript_to_tempfile as defence-in-depth.
2026-02-13 15:44:29 +04:00
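A hedged sketch of the hex-only ID sanitizer described above; the fallback shown here is illustrative (a later commit in this list internalizes max_len and adds its own fallback):

```python
import re
import uuid

def _sanitize_id(value: str, max_len: int = 64) -> str:
    """Keep only hex characters so session/user IDs are safe to embed in file paths."""
    cleaned = re.sub(r"[^0-9a-fA-F]", "", value)[:max_len]
    # Illustrative fallback: never return an empty path component.
    return cleaned or uuid.uuid4().hex
```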
Zamil Majdy
4bfd6c8870 fix(chat/sdk): address additional PR review feedback
- transcript: compare bytes-to-bytes in size guard (not str vs bytes)
- service: move user message preview from INFO to DEBUG level (PII)
2026-02-13 15:41:45 +04:00
Zamil Majdy
1918828405 fix(chat/sdk): flatten transcript storage path, remove duplicate session_id
Before: chat-transcripts/{user_id}/{session_id}/{session_id}.jsonl
After:  chat-transcripts/{user_id}/{session_id}.jsonl
2026-02-13 15:39:43 +04:00
Zamil Majdy
9c855b501b fix(chat/sdk): address PR review feedback on security and robustness
- security_hooks: use realpath instead of normpath to resolve symlinks
- security_hooks: check tool-results as path segment, not substring
- response_adapter: emit StreamFinish for unknown ResultMessage subtypes
- tool_adapter: delete file after read (prevent accumulation in pods)
- check_operation_status: guard against None.strip() from LLM null args
- service: remove redundant ".." check (realpath already resolves)
2026-02-13 15:37:22 +04:00
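A small sketch of the "path segment, not substring" check from the second bullet above (function name illustrative):

```python
import os

def _is_tool_results_path(path: str) -> bool:
    """Match 'tool-results' only as a whole path segment, so a directory like
    '/x/my-tool-results-backup/y' no longer slips past a naive substring check."""
    resolved = os.path.realpath(path)  # realpath also resolves symlinks
    return "tool-results" in resolved.split(os.sep)
```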
Zamil Majdy
5c9d0577c0 poetry lock 2026-02-13 15:34:01 +04:00
Zamil Majdy
a79bd88e7c feat(chat/sdk): move transcript storage from DB column to bucket
Replace the sdkTranscript TEXT column with WorkspaceStorageBackend
(GCS/local) for persisting Claude Code JSONL transcripts.  This removes
the implicit 512KB cap that caused --resume to degrade after a few
tool-heavy turns (JSONL is append-only and never shrinks).

Key changes:
- Strip progress/metadata entries before storing (~30% size reduction)
  with parentUuid reparenting for orphaned children
- Upload in background (asyncio.create_task) to avoid blocking SSE
- Size-based conflict guard: never overwrite a larger (newer) transcript
- Validate stripped content before upload
- Log warning when falling back to compression approach
- Enable claude_agent_use_resume by default
- Remove sdkTranscript column from schema, model, and DB layer
- Storage path: chat-transcripts/{user_id}/{session_id}/{session_id}.jsonl
2026-02-13 15:26:53 +04:00
Zamil Majdy
28c1121a8f fix(chat/sdk): block built-in Bash via disallowed_tools and resolve merge conflicts
- Add disallowed_tools=["Bash"] to SDK options so the model never tries
  the built-in Bash tool (previously it tried Bash, got blocked by the
  security hook, then fell back to bash_exec — wasting a turn)
- Resolve merge conflicts in tools/models.py (keep both HEAD additions
  and incoming BlockDetails/BlockDetailsResponse)
- Fix pyright error in find_block.py (pass categories to BlockInfoSummary)
2026-02-13 14:44:42 +04:00
Zamil Majdy
cb3839198c conflict resolve 2026-02-13 14:35:41 +04:00
Zamil Majdy
80804986b0 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-claude-code-continue-session 2026-02-13 14:30:09 +04:00
Zamil Majdy
b915e67a9b feat(chat/sdk): add stateless multi-turn resume via JSONL transcripts
Capture Claude Code CLI session transcripts via the Stop hook and persist
them in the DB. On subsequent turns, write the transcript to a temp file
and pass --resume so the CLI restores full conversation context without
lossy history compression.

Key changes:
- transcript.py: read/write/validate JSONL transcript utilities
- security_hooks: register Stop hook to capture transcript_path
- service.py: resume strategy with fallback to compression
- schema.prisma: add sdkTranscript column to ChatSession
- Feature flag: CLAUDE_AGENT_USE_RESUME (default off)
2026-02-13 13:48:04 +04:00
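A hedged sketch of the Stop-hook capture, assuming the hook payload carries a transcript_path key; the hook signature and registration details are simplified here, not the SDK's exact API:

```python
from dataclasses import dataclass

@dataclass
class CapturedTranscript:
    path: str | None = None

def make_stop_hook(captured: CapturedTranscript):
    """Return a Stop hook that records where the CLI wrote its JSONL transcript."""
    async def on_stop(payload: dict, *_args, **_kwargs) -> dict:
        # Assumption: the Stop payload includes the CLI's transcript_path.
        transcript_path = payload.get("transcript_path")
        if transcript_path:
            captured.path = transcript_path
        return {}
    return on_stop
```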
Zamil Majdy
32e9dda30d fix(chat/sdk): resolve relative paths in security hooks and unify workspace access
The security hook's path validation blocked SDK Read/Write tools because
it didn't resolve relative paths against sdk_cwd. Since the SDK sets cwd,
Claude naturally uses relative paths like "test.txt" which failed the
absolute path prefix check. Now relative paths are joined with sdk_cwd
before validation, and denial messages include the allowed workspace path.

Also clarifies the workspace model: SDK Read/Write + bash_exec share the
same ephemeral session directory, while workspace_file tools provide
persistent cloud storage across sessions.
2026-02-13 10:40:41 +04:00
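A minimal sketch of the relative-path handling described above (names illustrative):

```python
import os

def _resolve_against_workspace(path: str, sdk_cwd: str) -> str:
    """Join relative tool paths like 'test.txt' onto the SDK working directory
    before the absolute-prefix validation runs."""
    if not os.path.isabs(path):
        path = os.path.join(sdk_cwd, path)
    return os.path.realpath(path)
```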
Zamil Majdy
cb45e7957b feat: fix openapi.json 2026-02-12 23:39:47 +04:00
Zamil Majdy
f1d02fb8f3 fix(chat/sdk): move cwd setup inside try block to ensure cleanup
Move _make_sdk_cwd() and os.makedirs() inside the try block so
the finally cleanup always runs, preventing /tmp dir leaks if
setup fails.
2026-02-12 23:32:26 +04:00
Zamil Majdy
47de6b6420 feat(chat): add check_operation_status tool for long-running ops
Lets the CoPilot agent query whether a create_agent/edit_agent
operation is still running, completed, or failed. Accepts operation_id
or task_id from a previous operation_started response and looks up
the task status in Redis via stream_registry.
2026-02-12 23:30:51 +04:00
Zamil Majdy
62cd2eea89 fix(chat/sandbox): use --symlink for compat paths on Debian 13
On Debian 13 (bookworm+), /bin, /lib, /sbin, /lib64 are symlinks to
/usr/*.  bwrap --ro-bind cannot create a symlink as a mount target
inside the sandbox, causing "execvp: No such file or directory" because
the ELF dynamic linker at /lib64/ld-linux-x86-64.so.2 is unreachable.

Detect symlinks at runtime with os.path.islink() and use bwrap
--symlink instead of --ro-bind.  Falls back to --ro-bind on older
distros where these are real directories.
2026-02-12 22:55:29 +04:00
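A sketch of the runtime detection described above, assuming the bwrap arguments are assembled in Python; the compat path list follows the commit, the helper name is illustrative:

```python
import os

_COMPAT_PATHS = ("/bin", "/lib", "/lib64", "/sbin")

def compat_mount_args() -> list[str]:
    args: list[str] = []
    for path in _COMPAT_PATHS:
        if os.path.islink(path):
            # Merged-/usr distro: recreate the symlink inside the sandbox so the
            # dynamic linker at /lib64/ld-linux-x86-64.so.2 stays reachable.
            args += ["--symlink", os.readlink(path), path]
        elif os.path.isdir(path):
            # Older distros where these are real directories.
            args += ["--ro-bind", path, path]
    return args
```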
Zamil Majdy
ae61ec692e Merge branch 'dev' into feat/copitlot-claude-code 2026-02-12 22:27:50 +04:00
Zamil Majdy
9296bd8736 fix(chat/sandbox): fix bwrap inside Docker containers
Three fixes for bubblewrap sandbox:
- Replace --tmpdir (not a valid bwrap option) with --tmpfs
- Add --unshare-user so bwrap can create namespaces inside
  unprivileged Docker containers (no CAP_SYS_ADMIN needed)
- Reorder mounts: --tmpfs /tmp first, then --bind workspace on top,
  so the workspace directory is visible through the fresh tmpfs
2026-02-12 22:22:39 +04:00
Zamil Majdy
308113c03d fix(chat/sdk): remove obsolete Bash allowlist tests
The SDK built-in Bash tool is now unconditionally blocked (bash_exec
MCP tool with bubblewrap is used instead). Remove tests that expected
safe Bash commands to be allowed and replace with a single test that
verifies Bash is always denied.
2026-02-12 22:19:30 +04:00
Zamil Majdy
51abf13254 feat(chat): use LaunchDarkly flag for copilot SDK rollout
Replace static CHAT_USE_CLAUDE_AGENT_SDK env var with a LaunchDarkly
feature flag (copilot-sdk) for per-user rollout control. The env var
value serves as the default when LD is not configured or the flag
doesn't exist yet.
2026-02-12 22:02:28 +04:00
Zamil Majdy
54b03d3a29 fix(frontend): remove python_exec from openapi.json ResponseType enum
The python_exec tool was removed from the backend but the generated
openapi.json still referenced the enum value.
2026-02-12 21:55:25 +04:00
Zamil Majdy
239dff5ebd feat(chat/sandbox): add resource limits to bubblewrap sandbox
Add ulimit-based resource caps inside the bwrap sandbox to prevent
fork bombs and resource exhaustion:
- max 64 processes (stops fork bombs)
- 512 MB virtual memory
- 50 MB max file size
- 256 open file descriptors

Limits are applied via `sh -c 'ulimit ...; exec "$@"'` wrapper inside
the sandbox, so they're inherited by all child processes.
2026-02-12 21:47:49 +04:00
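A hedged sketch of the wrapper construction; ulimit units are shell-dependent, so the numeric values below are illustrative rather than the exact flags used:

```python
# Illustrative limits: process count, virtual memory, file size, open fds.
# ulimit units vary by shell (-v is KiB in bash, -f is in blocks), so real
# values must be checked against the shell inside the sandbox.
_ULIMITS = "ulimit -u 64 -v 524288 -f 102400 -n 256"

def wrap_with_ulimits(command: list[str]) -> list[str]:
    """Run the target under sh -c 'ulimit ...; exec "$@"' so every child
    process inside the sandbox inherits the limits."""
    return ["sh", "-c", f'{_ULIMITS}; exec "$@"', "sh", *command]
```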
Zamil Majdy
1dd53db21c feat(chat/sandbox): bubblewrap sandbox for bash_exec, remove python_exec
- Replace `--ro-bind / /` with whitelist-only filesystem: only /usr, /etc,
  /bin, /lib, /sbin mounted read-only. /app, /root, /home, /opt, /var are
  completely invisible inside the sandbox.
- Add `--clearenv` to wipe all inherited env vars (API keys, DB passwords).
  Only safe vars (PATH, HOME=workspace, LANG) are explicitly set.
- Remove python_exec tool — bash_exec can run `python3 -c` or heredocs with
  identical bubblewrap protection, reducing attack surface.
- Remove all fallback security code (import hooks, blocked modules, network
  command lists). Tools now hard-require bubblewrap — disabled on platforms
  without bwrap.
- Clean up security_hooks.py: remove ~200 lines of dead bash validation code,
  add Bash to BLOCKED_TOOLS as defence-in-depth.
- Wire up long-running tool callback in SDK service for create_agent/edit_agent
  delegation to Redis Streams background infrastructure.
2026-02-12 21:44:40 +04:00
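A condensed sketch of how such a whitelist-only invocation can be assembled; the mount list and safe env vars follow the bullets above, while the remaining flags (including --unshare-all) are illustrative, and on merged-/usr systems the --symlink handling from commit 62cd2eea89 above applies to /bin, /lib, /sbin:

```python
def build_bwrap_cmd(workspace: str, command: list[str]) -> list[str]:
    cmd = ["bwrap", "--unshare-all", "--clearenv"]
    # Whitelist-only read-only mounts; /app, /root, /home, /opt, /var are never mounted.
    for path in ("/usr", "/etc", "/bin", "/lib", "/sbin"):
        cmd += ["--ro-bind", path, path]
    # Fresh /tmp first, then the session workspace bound on top of it.
    cmd += ["--tmpfs", "/tmp", "--bind", workspace, workspace, "--chdir", workspace]
    # Only explicitly safe variables survive --clearenv.
    cmd += ["--setenv", "PATH", "/usr/bin:/bin",
            "--setenv", "HOME", workspace,
            "--setenv", "LANG", "C.UTF-8"]
    return cmd + command
```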
Zamil Majdy
06c16ee2fe fix(chat/sdk): non-blocking long-running tools, tighten security
- Long-running tools (create_agent) now run in background and return
  immediately with an operation_id. Add check_operation MCP tool for
  polling results. Prevents 3+ min blocking and survives page refresh.
- Fix CodeQL path traversal alert: use normpath+startswith sanitizer
  in _make_sdk_cwd() instead of assert.
- Tighten _read_file_handler: restrict from ~/.claude/ to only
  ~/.claude/projects/**/tool-results/ (sentry review feedback).
- Fix bash redirect bypass: strip quoted strings before checking for
  unquoted > operator, catches `echo hello>file` (sentry review).
2026-02-12 20:39:33 +04:00
Zamil Majdy
8d2a649ee5 refactor(chat/sdk): remove Langfuse tracing — OpenRouter handles observability
Delete tracing.py (~408 lines) and all TracedSession/hook references from the
SDK path. OpenRouter already provides token usage, cost tracking, and request
logging, making manual Langfuse integration redundant. This also fixes the
broken 'Langfuse' object has no attribute 'trace' warning on every request.
2026-02-12 20:24:27 +04:00
Zamil Majdy
9589474709 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-12 19:40:32 +04:00
Zamil Majdy
749a78723a refactor(chat/sdk): deduplicate code and remove anthropic fallback
- Extract shared `make_session_path()` into sandbox.py (single source of
  truth for workspace path sanitization), replace duplicate in service.py
- Delete anthropic_fallback.py (~360 lines) — redundant third code path;
  routes.py already falls back to non-SDK service
- Remove dead `traced_session()`, `get_tool_definitions()`,
  `get_tool_handlers()`, `_current_tool_call_id` ContextVar
- Fix hardcoded model in tracing — pass actual resolved model
- Fix inconsistent model name splitting in anthropic fallback
2026-02-12 19:26:29 +04:00
Zamil Majdy
bec2e1ddee fix(chat/tools): sanitize session_id in sandbox workspace path
Align with SDK's _make_sdk_cwd() to prevent path traversal and ensure
python_exec/bash_exec share the same workspace as SDK file tools.
2026-02-12 19:08:47 +04:00
Zamil Majdy
ec1ab06e0d chore(chat): bump default max_subtasks from 3 to 10 2026-02-12 19:07:42 +04:00
Zamil Majdy
f31cb49557 feat(chat/tools): add sandboxed python_exec, bash_exec, web_fetch tools and enable Task
- Add sandbox.py with network-isolated execution via unshare --net (Linux)
  and import/command blocklist fallback (macOS dev)
- Add python_exec tool: runs Python in subprocess with no network, workspace-scoped
- Add bash_exec tool: full Bash scripting with no network, workspace-scoped
- Add web_fetch tool: SSRF-protected URL fetching via backend Requests utility
- Remove SDK built-in Bash from allowlist (replaced by sandboxed bash_exec)
- Enable SDK built-in Task (sub-agents) with per-session rate limit (default 3)
- Add claude_agent_max_subtasks config field
2026-02-12 19:07:19 +04:00
Zamil Majdy
fd28c386f4 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-12 18:50:11 +04:00
Zamil Majdy
3bea584659 feat(chat/sdk): route SDK through OpenRouter with observability (#12084)
## Summary
- Routes Claude Agent SDK API calls through OpenRouter via
`ANTHROPIC_BASE_URL` / `ANTHROPIC_AUTH_TOKEN` env vars, enabling
per-call token and cost tracking on the OpenRouter dashboard
- Adds `sdk_model` and `sdk_max_budget_usd` config fields for
SDK-specific model selection and budget control
- Emits `StreamUsage` from SDK `ResultMessage` so the frontend receives
token counts, and persists usage to `session.usage`
- Fixes Langfuse tracing to use the configured model name instead of a
hardcoded default
- Updates Anthropic fallback to use `config.api_key` / `config.base_url`
(OpenRouter routing) instead of raw `ANTHROPIC_API_KEY` env var

## Test plan
- [ ] Deploy and send a CoPilot message — verify the API call appears on
the OpenRouter dashboard
- [ ] Check Langfuse trace shows correct model name (e.g.
`claude-opus-4.6` not hardcoded `claude-sonnet-4-20250514`)
- [ ] Verify frontend receives `StreamUsage` with `promptTokens` /
`completionTokens` values
- [ ] Set `CHAT_SDK_MAX_BUDGET_USD` and verify budget is respected
- [ ] Test fallback path (without `claude-agent-sdk` installed) still
works via OpenRouter

## Greptile Overview

### Greptile Summary

Routes Claude Agent SDK API calls through OpenRouter for enhanced
observability and cost tracking. The PR enables per-call token tracking
on the OpenRouter dashboard by configuring the SDK to use
`ANTHROPIC_BASE_URL` and `ANTHROPIC_AUTH_TOKEN` environment variables
derived from the chat configuration.

Key changes:
- Added `sdk_model` and `sdk_max_budget_usd` configuration fields for
SDK-specific control
- Implemented automatic model name resolution that strips OpenRouter
provider prefixes
- Updated SDK client initialization to route through OpenRouter with
proper environment variables
- Emits `StreamUsage` events from SDK `ResultMessage` for frontend token
visibility
- Persists usage data to `session.usage` for historical tracking
- Fixed Langfuse tracing to use the configured model name instead of
hardcoded defaults
- Updated fallback path to use OpenRouter routing instead of direct
Anthropic API

### Confidence Score: 4/5

- Safe to merge with minor observations - the implementation is solid
and the changes are well-structured
- The code quality is high with proper error handling, clear separation
of concerns, and good defensive coding practices. The changes integrate
cleanly with existing patterns. Minor observations include missing
validation for sdk_max_budget_usd and a potential edge case in model
name resolution, but these don't block merging
- No files require special attention - all changes follow existing
patterns and maintain consistency

### Sequence Diagram

```mermaid
sequenceDiagram
    participant Frontend
    participant Backend
    participant SDK as Claude Agent SDK
    participant OpenRouter
    participant Anthropic
    participant Langfuse

    Frontend->>Backend: POST /chat/completions
    Backend->>Backend: Load config (api_key, base_url)
    Backend->>Backend: Resolve SDK model (strip OpenRouter prefix)
    Backend->>Backend: Build SDK env vars (ANTHROPIC_BASE_URL, ANTHROPIC_AUTH_TOKEN)
    
    Backend->>Langfuse: Initialize TracedSession with model name
    Backend->>SDK: ClaudeSDKClient(model, env, max_budget_usd)
    
    SDK->>SDK: Use ANTHROPIC_BASE_URL from env
    SDK->>OpenRouter: POST /messages (via configured base_url)
    OpenRouter->>Anthropic: Forward request with routing
    Anthropic-->>OpenRouter: Stream response chunks
    OpenRouter-->>SDK: Stream response with usage data
    
    loop For each SDK message
        SDK-->>Backend: AssistantMessage/UserMessage/ResultMessage
        Backend->>Langfuse: log_sdk_message()
        Backend->>Backend: SDKResponseAdapter.convert_message()
        Backend->>Backend: Extract usage from ResultMessage
        Backend->>Backend: Persist Usage to session.usage
        Backend-->>Frontend: StreamUsage(promptTokens, completionTokens)
        Backend-->>Frontend: StreamTextDelta/StreamToolInput/etc
    end
    
    Backend->>Langfuse: Log final generation with model name
    Backend->>Backend: Save session with usage data
    Backend-->>Frontend: StreamFinish
```

2026-02-12 21:47:39 +07:00
Zamil Majdy
d7f7a2747f fix(backend/chat): Atomic message append to prevent race condition
Replace the read-modify-write pattern in stream_chat_post with an
atomic append_and_save_message helper that acquires the session lock
before re-fetching and appending. This prevents message loss when
concurrent requests modify the same session.
2026-02-12 09:10:43 +04:00
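A sketch of the append-under-lock pattern from the commit above; an in-process asyncio.Lock and the fetch/save helpers stand in for whatever locking and persistence layer the service actually uses:

```python
import asyncio

_session_locks: dict[str, asyncio.Lock] = {}

async def append_and_save_message(session_id: str, message: dict) -> None:
    """Re-fetch and append under a per-session lock so two concurrent requests
    cannot race on the read-modify-write and drop each other's messages."""
    lock = _session_locks.setdefault(session_id, asyncio.Lock())
    async with lock:
        session = await fetch_session(session_id)  # assumed persistence helper
        session.messages.append(message)
        await save_session(session)                # assumed persistence helper
```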
Zamil Majdy
68849e197c format 2026-02-12 08:26:26 +04:00
Zamil Majdy
211478bb29 Revert "style: run ruff format and isort"
This reverts commit 40b58807ab.
2026-02-12 08:25:22 +04:00
Zamil Majdy
0e88dd15b2 feat(chat): add hook-based tracing integration for Claude Agent SDK
- Add create_tracing_hooks() for fine-grained tool timing
- Add merge_hooks() utility to combine security + tracing hooks
- Captures precise pre/post timing for tool executions
- Tracks tool failures via PostToolUseFailure hook
- Integrates seamlessly with existing security hooks
2026-02-12 03:35:16 +00:00
Zamil Majdy
7f3c227f0a feat(chat): add modular Langfuse tracing for Claude Agent SDK
- Create tracing.py with TracedSession context manager
- Automatically trace user messages, SDK messages, and results
- Capture tool calls with input/output and timing
- Log usage and cost from SDK ResultMessage
- No-op when Langfuse not configured (zero overhead)
- Clean integration into service.py via context manager
2026-02-12 03:33:37 +00:00
Zamil Majdy
40b58807ab style: run ruff format and isort 2026-02-12 03:25:19 +00:00
Zamil Majdy
d0e2e6f013 security(service): strengthen path validation for SDK cleanup
- Add empty check after session_id sanitization
- Add assertion for defense-in-depth
- Add explicit '..' traversal check in cleanup
- Replace glob with os.listdir to avoid glob injection
- Add validation that project_dir stays under ~/.claude/projects
- Add warning logs for rejected paths

Addresses CodeQL alert about uncontrolled data in path expression
2026-02-12 03:07:08 +00:00
Zamil Majdy
efdc8d73cc fix(security_hooks): use json.dumps for pattern matching and log warning
- Use json.dumps instead of str() for more predictable pattern matching
- Log warning when SDK not available and security hooks are disabled

Addresses CodeRabbit review feedback
2026-02-12 02:55:04 +00:00
Zamil Majdy
a34810d8a2 revert: remove Bash command extraction from GenericTool
Keep it simple: just show 'Bash completed' instead of special-casing
the extraction of command names like 'jq completed'
2026-02-12 02:53:37 +00:00
Zamil Majdy
038b7d5841 feat(copilot): show specific command name for Bash tool
- Extract command name (jq, grep, etc.) from Bash tool input
- Display 'jq completed' instead of 'Bash completed'
- Add ripgrep and tree to Dockerfile (match ALLOWED_BASH_COMMANDS)
2026-02-12 02:48:19 +00:00
Zamil Majdy
cac93b0cc9 fix(chat): increase SDK buffer limit and add jq
- Add sdk_max_buffer_size config option (default 10MB, was 1MB)
- Pass max_buffer_size to ClaudeAgentOptions to prevent crashes on large tool outputs
- Install jq in Dockerfile for JSON processing capabilities

Fixes AUTOGPT-SERVER-7V2
2026-02-12 02:41:12 +00:00
Zamil Majdy
2025aaf5f2 fix(backend/chat): Preserve full MCP tool output for frontend widgets
The SDK CLI truncates large tool results (writing them to disk),
which breaks frontend widget rendering (e.g., find_block's block
list cards). Stash the full MCP tool output before the SDK sees it,
then use the stash in the response adapter so the frontend always
receives the complete JSON for proper widget parsing.
2026-02-11 23:13:42 +04:00
Zamil Majdy
ae9bce3bae feat(backend/chat): Add sandboxed Bash and notify SDK of restrictions
- Allow Bash tool with command allowlist (jq, grep, head, tail, etc.)
  validated via shlex.split for proper quote handling
- Add workspace path validation for Bash absolute paths
- Add SDK built-in tools (Read/Write/Edit/Glob/Grep/Bash) to allowed_tools
- Append Bash restrictions to system prompt (SDK doesn't know our allowlist)
- Add default_factory to BlockInfoSummary schema fields
- Add 12 Bash sandbox tests covering safe/dangerous commands, substitution,
  redirection, /dev/ access, path escaping
2026-02-11 22:35:39 +04:00
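A minimal sketch of the shlex-based allowlist check from the first bullet above; only the commands named in the commit are certain, and the helper name is illustrative:

```python
import shlex

ALLOWED_BASH_COMMANDS = {"jq", "grep", "head", "tail"}  # plus others per the commit

def is_allowed_bash(command: str) -> bool:
    """Allow only commands whose first token is on the allowlist; shlex.split
    handles quoting, so `grep "a b" file.txt` parses into proper tokens."""
    try:
        tokens = shlex.split(command)
    except ValueError:  # unbalanced quotes and similar parse errors
        return False
    return bool(tokens) and tokens[0] in ALLOWED_BASH_COMMANDS
```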
Zamil Majdy
3107d889fc feat(frontend/copilot): Add generic tool widget for unrecognized tools
SDK built-in tools (Read, Glob, Grep, etc.) have no dedicated frontend
widget, so tool calls silently disappeared. Add a GenericTool component
that shows a spinning gear + "Running {tool}…" for any tool-* part
type that doesn't match a known case.
2026-02-11 22:08:03 +04:00
Zamil Majdy
f174fb6303 fix(backend/chat): Strip MCP prefix from SDK tool names for frontend rendering
The Vercel AI SDK frontend renders tool widgets based on tool name
(e.g. "tool-find_block", "tool-run_agent"). The SDK sends tool names
with the MCP prefix (mcp__copilot__find_block) which didn't match
any frontend switch case, causing tool execution to be invisible.

Strip the mcp__copilot__ prefix in the response adapter so tool events
reach the correct frontend widget handlers.
2026-02-11 22:01:59 +04:00
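The prefix strip itself is a one-liner; a sketch with an illustrative helper name (removeprefix requires Python 3.9+):

```python
_MCP_PREFIX = "mcp__copilot__"

def to_frontend_tool_name(sdk_tool_name: str) -> str:
    """Map 'mcp__copilot__find_block' -> 'find_block' so the frontend's
    'tool-<name>' switch cases match again."""
    return sdk_tool_name.removeprefix(_MCP_PREFIX)
```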
Zamil Majdy
920a4c5f15 feat(backend/chat): Allow Read/Write/Edit/Glob/Grep in SDK within workspace
Move these tools from fully-blocked to workspace-scoped: they are now
allowed when the file path stays within the SDK working directory
(/tmp/copilot-<session>/) or the tool-results directory
(~/.claude/projects/…/tool-results/). This enables the SDK's built-in
oversized tool result handling and workspace file operations.

- Add _validate_workspace_path() with normpath-based path validation
- Pass sdk_cwd from service.py into create_security_hooks()
- Add 20 unit tests covering allowed/denied paths, traversal attacks
2026-02-11 20:39:33 +04:00
Zamil Majdy
e95fadbb86 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-11 20:23:56 +04:00
Zamil Majdy
b14b3803ad feat(backend/chat): Add StreamStartStep/StreamFinishStep to SDK adapter
The non-SDK path emits step boundaries (StartStep/FinishStep) around
each LLM turn and tool cycle. The SDK adapter was missing these,
causing the frontend to lack visual step framing for tool calls.

Now the SDK adapter emits:
- StreamStartStep after init and before each new LLM turn
- StreamFinishStep after tool results and before final finish
2026-02-11 20:18:27 +04:00
Zamil Majdy
82c483d6c8 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-11 07:17:38 +04:00
Zamil Majdy
7cffa1895f fix(backend/chat): Filter duplicate StreamStart from non-SDK path
Routes.py already publishes a StreamStart before calling the service.
The SDK path filters the duplicate internally, but the non-SDK path
did not, causing two StreamStart events to reach the frontend.
2026-02-11 06:52:47 +04:00
Zamil Majdy
9791bdd724 fix(backend/chat): Use normpath+startswith pattern for CodeQL path sanitization
CodeQL doesn't recognize re.sub as a path sanitizer. Switch to the
os.path.normpath + startswith prefix check pattern that CodeQL's
taint model explicitly recognizes as breaking the taint chain.
2026-02-11 06:45:12 +04:00
Zamil Majdy
750a674c78 fix lock 2026-02-11 06:39:03 +04:00
Zamil Majdy
960c7980a3 fix(backend/chat): Use named helper for session_id sanitization to satisfy CodeQL
Replace inline comprehension with _sanitize_session_id() using re.sub
so CodeQL recognizes the path-traversal sanitization barrier.
2026-02-11 06:32:16 +04:00
Zamil Majdy
e85d437bb2 fix(backend/chat): Sanitize session_id in SDK cwd path to prevent path traversal 2026-02-11 06:26:48 +04:00
Zamil Majdy
44f9536bd6 fix lock 2026-02-11 06:24:41 +04:00
Zamil Majdy
1c1085a227 Merge remote-tracking branch 'origin/dev' into feat/copitlot-claude-code
# Conflicts:
#	autogpt_platform/backend/backend/api/features/chat/config.py
#	autogpt_platform/backend/poetry.lock
2026-02-11 05:30:46 +04:00
Zamil Majdy
d7ef70469e fix(backend/chat): Fix cleanup race condition and move to outer finally
- Use session-specific temp dir (/tmp/copilot-{session_id}) as SDK cwd
  to prevent concurrent sessions from deleting each other's tool-result
  files during cleanup
- Move _cleanup_sdk_tool_results() to outer finally block so it runs
  even when the outer except Exception fires
- Clean up the temp cwd directory after each session
- Remove unnecessary inner try/finally nesting
2026-02-11 05:13:02 +04:00
Zamil Majdy
1926127ddd fix(backend/chat): Fix bugs and remove dead code in SDK integration
- Fix message accumulation bug: reset has_appended_assistant when
  creating new post-tool assistant message to prevent lost text deltas
- Fix hardcoded model in anthropic_fallback.py: use config.model instead
  of hardcoded "claude-sonnet-4-20250514"
- Fix _SDK_TOOL_RESULTS_DIR using hardcoded /root/ path: use expanduser
- Remove unused create_strict_security_hooks (~75 lines)
- Remove unused create_heartbeat/create_usage from response adapter
- Remove unused RAW_TOOL_NAMES from tool_adapter
- Extract _MAX_TOOL_ITERATIONS constant from magic number
2026-02-11 04:42:05 +04:00
Zamil Majdy
8b509e56de refactor(backend/chat): Replace --resume with conversation context, add compaction and dedup
- Remove broken --resume/session file approach (CLI v2.1.38 can't load
  >2 message session files) and delete session_file.py + tests
- Embed prior conversation turns as <conversation_history> context in
  the user message for multi-turn memory
- Add context compaction using shared compress_context() from prompt.py
  with LLM summarization + truncation fallback for long conversations
- Reuse _build_system_prompt and _generate_session_title from parent
  service.py instead of duplicating (gains Langfuse prompt support)
- Add has_conversation_history param to _build_system_prompt to avoid
  greeting on multi-turn conversations
- Fix _SDK_TOOL_RESULTS_GLOB from hardcoded /root/ to expanduser ~/
2026-02-11 04:22:11 +04:00
Zamil Majdy
acb2d0bd1b fix(backend/chat): Resolve symlinks in session file path for --resume
The CLI resolves symlinks when computing its project directory (e.g.
/tmp -> /private/tmp on macOS), so our session file writes must use
the resolved path to match. Also adds cwd to ClaudeAgentOptions and
debug logging for SDK messages.
2026-02-10 20:11:16 +04:00
Zamil Majdy
51aa369c80 fix(backend): Restore PyYAML cp38 wheel entries in poetry.lock
Re-add Python 3.8 wheel entries for PyYAML that were dropped by
poetry lock resolution, keeping the lockfile consistent with dev.
2026-02-10 20:06:45 +04:00
Zamil Majdy
6403ffe353 fix(backend/chat): Use --resume with session files for multi-turn conversations
Replace broken AsyncIterable approach (CLI rejects assistant-type stdin
messages) with JSONL session files written to the CLI's storage directory.
This enables --resume to load full user+assistant context with turn-level
compaction support for long conversations.
2026-02-10 18:46:33 +04:00
Zamil Majdy
c40a98ba3c Merge branches 'feat/copitlot-claude-code' and 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copitlot-claude-code 2026-02-10 18:19:23 +04:00
Zamil Majdy
a31fc8b162 refactor(backend/chat): Use proper SDK types and in-memory conversation history
Replace duck typing (class name checks, getattr) with isinstance() using
SDK-exported dataclasses. Replace file-based --resume with AsyncIterable
message injection for conversation history, eliminating disk I/O. Add 15
unit tests for the response adapter.
2026-02-10 18:17:00 +04:00
Zamil Majdy
0f2d1a6553 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-10 17:23:06 +04:00
Zamil Majdy
87d817b83b fix(backend/chat): Allow MCP-registered tools through security hook and fix title generation
- Skip BLOCKED_TOOLS check for tools with mcp__copilot__ prefix since they
  are already sandboxed by tool_adapter (fixes Read tool being blocked)
- Fall back to session.messages for title generation when message=None
2026-02-10 17:15:42 +04:00
Zamil Majdy
acf932bf4f refactor(backend/chat): Move glob/os imports to top-level in SDK service 2026-02-10 16:57:11 +04:00
Zamil Majdy
f562d9a277 fix(backend/chat): Add Read tool for SDK oversized tool results
The Claude Agent SDK saves tool results exceeding its token limit to
files and instructs the agent to read them back with a Read tool. Our
MCP server didn't have this tool, breaking the agent on large results
like run_block output (117K+ chars).

Changes:
- Add a Read tool to the MCP server (restricted to /root/.claude/)
- Register it in COPILOT_TOOL_NAMES so the SDK can use it
- Add safety-net truncation at 500K chars for extreme cases
- Clean up SDK tool-result files after each client session
2026-02-10 16:53:04 +04:00
Zamil Majdy
3c92a96504 fix(backend/chat): Publish StreamError before StreamFinish on error paths
When run_ai_generation() or event_generator() encounter errors, they
were only publishing StreamFinish without a preceding StreamError. The
frontend treats finish-without-error as normal completion, leaving the
user with an apparently stuck/empty response requiring a page refresh.
2026-02-10 15:49:23 +04:00
Zamil Majdy
8b8e1df739 fix(backend/chat): Auto-expire stale running tasks to unblock sessions
Tasks stuck in "running" status beyond stream_timeout (300s) are now
auto-marked as failed when looked up, preventing zombie tasks from
blocking the session indefinitely.
2026-02-10 15:35:43 +04:00
Zamil Majdy
602a0a4fb1 fix(backend/chat): Strip tool call noise from conversation history context 2026-02-10 14:11:27 +04:00
Zamil Majdy
8d7d531ae0 refactor(backend/chat): Remove unused max_context_messages config 2026-02-10 13:57:33 +04:00
Zamil Majdy
43153a12e0 fix(backend/chat): Remove manual context truncation from SDK path, let SDK handle compaction 2026-02-10 13:52:49 +04:00
Zamil Majdy
587e11c60a refactor(backend/chat): Extract MCP server name constants to avoid hardcoded strings 2026-02-10 12:12:08 +04:00
Zamil Majdy
57da545e02 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-10 12:10:35 +04:00
Zamil Majdy
626980bf27 Merge branch 'dev' into feat/copitlot-claude-code 2026-02-09 19:26:52 +04:00
Swifty
e42b27af3c Merge branch 'dev' into feat/copitlot-claude-code 2026-02-09 09:12:23 +01:00
Zamil Majdy
34face15d2 fix lock 2026-02-09 11:45:59 +04:00
Zamil Majdy
7d32c83f95 fix(backend/chat): Handle non-serializable SDK objects in tool result output 2026-02-09 10:59:50 +04:00
Zamil Majdy
6e2a45b84e style(backend): Remove unused pytest import in execution_queue_test 2026-02-09 10:14:20 +04:00
Zamil Majdy
32f6532e9c Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copitlot-claude-code 2026-02-09 10:10:32 +04:00
Zamil Majdy
0bbe8a184d Merge dev and resolve poetry.lock conflict 2026-02-08 19:40:17 +04:00
Zamil Majdy
7592deed63 fix(backend/chat): Address remaining PR review comments
- Fix tool_call_id always being "sdk-call" by generating unique IDs per invocation
- Fix validation using original tool_name instead of clean_name in security hooks
- Fix duplicate StreamFinish in Anthropic fallback path
- Fix ImportError fallback returning plain dict instead of re-raising
- Extract _build_input_schema helper to deduplicate schema construction
- Add else branch for unhandled SDK message types for observability
- Truncate large tool results in conversation history to prevent context overflow
2026-02-08 19:39:10 +04:00
Zamil Majdy
b9c759ce4f fix(backend/chat): Address additional PR review comments
- Add terminal StreamFinish in adapt_sdk_stream if SDK ends without one
- Sanitize error message in adapt_sdk_stream exception handler
- Pass full JSON schema (type, properties, required) to tool decorator
2026-02-08 07:14:45 +04:00
Zamil Majdy
5efb80d47b fix(backend/chat): Address PR review comments for Claude SDK integration
- Add StreamFinish after ErrorMessage in response adapter
- Fix str.replace to removeprefix in security hooks
- Apply max_context_messages limit as safety guard in history formatting
- Add empty prompt guard before sending to SDK
- Sanitize error messages to avoid exposing internal details
- Fix fire-and-forget asyncio.create_task by storing task reference
- Fix tool_calls population on assistant messages
- Rewrite Anthropic fallback to persist messages and merge consecutive roles
- Only use ANTHROPIC_API_KEY for fallback (not OpenRouter keys)
- Fix IndexError when tool result content list is empty
2026-02-06 13:25:10 +04:00
Zamil Majdy
b49d8e2cba fix lock 2026-02-06 13:19:53 +04:00
Zamil Majdy
452544530d feat(chat/sdk): Enable native SDK context compaction
- Remove manual truncation in conversation history formatting
- SDK's automatic compaction handles context limits intelligently
- Add observability hooks:
  - PreCompact: Log when SDK triggers context compaction
  - PostToolUse: Log successful tool executions
  - PostToolUseFailure: Log and debug failed tool executions
- Update config: increase max_context_messages (SDK handles compaction)
2026-02-06 12:44:48 +04:00
Zamil Majdy
32ee7e6cf8 fix(chat): Remove aggressive stale task detection
The 60-second timeout was too aggressive and could incorrectly mark
legitimate long-running tool calls as stale. Relying on Redis TTL
(1 hour) for cleanup is sufficient and more reliable.
2026-02-06 11:45:54 +04:00
Zamil Majdy
670663c406 Merge dev and resolve poetry.lock conflict 2026-02-06 11:40:41 +04:00
Zamil Majdy
0dbe4cf51e feat(backend/chat): Add Claude Agent SDK integration for CoPilot
This PR adds Claude Agent SDK as the default backend for CoPilot chat completions,
replacing the direct OpenAI API integration.

Key changes:
- Add Claude Agent SDK service layer with MCP tool adapter
- Fix message persistence after tool calls (messages no longer disappear on refresh)
- Add OpenRouter tracing for session title generation
- Add security hooks for user context validation
- Add Anthropic fallback when SDK is not available
- Clean up excessive debug logging
2026-02-06 11:38:17 +04:00
97 changed files with 4872 additions and 9768 deletions

View File

@@ -66,13 +66,19 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*

View File

@@ -122,24 +122,6 @@ class ConnectionManager:
return len(connections)
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
"""Broadcast a message to all active websocket connections."""
message = WSMessage(
method=method,
data=data,
).model_dump_json()
connections = tuple(self.active_connections)
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()

View File

@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
# Return with provider prefix for clarity
return f"{provider_name}: {model_name}"
# Get all models from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
from backend.server.v2.llm import db as llm_db
# Get the recommended model from the database (configurable via admin UI)
recommended_model_slug = await llm_db.get_recommended_model_slug()
# Build the available models list
first_enabled_slug = None
for registry_model in llm_registry.iter_dynamic_models():
# Only include enabled models in the list
if not registry_model.is_enabled:
continue
# Track first enabled model as fallback
if first_enabled_slug is None:
first_enabled_slug = registry_model.slug
model = LlmModel(registry_model.slug)
# Include all LlmModel values (no more filtering by hardcoded list)
recommended_model = LlmModel.GPT4O_MINI.value
for model in LlmModel:
label = generate_model_label(model)
# Add "(Recommended)" suffix to the recommended model
if registry_model.slug == recommended_model_slug:
if model.value == recommended_model:
label += " (Recommended)"
available_models.append(
ModelInfo(
value=registry_model.slug,
value=model.value,
label=label,
provider=registry_model.metadata.provider,
provider=model.provider,
)
)
# Sort models by provider and name for better UX
available_models.sort(key=lambda x: (x.provider, x.label))
# Handle case where no models are available
if not available_models:
logger.warning(
"No enabled LLM models found in registry. "
"Ensure models are configured and enabled in the LLM Registry."
)
# Provide a placeholder entry so admins see meaningful feedback
available_models.append(
ModelInfo(
value="",
label="No models available - configure in LLM Registry",
provider="none",
)
)
# Use the DB recommended model, or fallback to first enabled model
final_recommended = recommended_model_slug or first_enabled_slug or ""
return ExecutionAnalyticsConfig(
available_models=available_models,
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
default_user_prompt=DEFAULT_USER_PROMPT,
recommended_model=final_recommended,
recommended_model=recommended_model,
)

View File

@@ -1,593 +0,0 @@
import logging
import autogpt_libs.auth
import fastapi
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
tags=["llm", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
try:
# Refresh registry from database
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.blocks._base import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder caches
try:
from backend.api.features.builder import db as builder_db
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
builder_db._build_cached_search_results.cache_clear()
logger.info("Cleared v2 builder search results cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
except Exception as exc:
logger.exception(
"LLM runtime state refresh failed; caches may be stale: %s", exc
)
@router.get(
"/providers",
summary="List LLM providers",
response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
providers = await llm_db.list_providers(include_models=include_models)
return llm_model.LlmProvidersResponse(providers=providers)
@router.post(
"/providers",
summary="Create LLM provider",
response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
provider = await llm_db.upsert_provider(request=request)
await _refresh_runtime_state()
return provider
@router.patch(
"/providers/{provider_id}",
summary="Update LLM provider",
response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
provider_id: str,
request: llm_model.UpsertLlmProviderRequest,
):
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
await _refresh_runtime_state()
return provider
@router.delete(
"/providers/{provider_id}",
summary="Delete LLM provider",
response_model=dict,
)
async def delete_llm_provider(provider_id: str):
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Delete all models from the provider first before deleting the provider.
"""
try:
await llm_db.delete_provider(provider_id)
await _refresh_runtime_state()
logger.info("Deleted LLM provider '%s'", provider_id)
return {"success": True, "message": "Provider deleted successfully"}
except ValueError as e:
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=500, detail=str(e))
@router.get(
"/models",
summary="List LLM models",
response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
provider_id: str | None = fastapi.Query(default=None),
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
return await llm_db.list_models(
provider_id=provider_id, page=page, page_size=page_size
)
@router.post(
"/models",
summary="Create LLM model",
response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
model = await llm_db.create_model(request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}",
summary="Update LLM model",
response_model=llm_model.LlmModel,
)
async def update_llm_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
):
model = await llm_db.update_model(model_id=model_id, request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}/toggle",
summary="Toggle LLM model availability",
response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
model_id: str,
request: llm_model.ToggleLlmModelRequest,
):
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
If disabling a model and `migrate_to_slug` is provided, all workflows using
this model will be migrated to the specified replacement model before disabling.
A migration record is created which can be reverted later using the revert endpoint.
Optional fields:
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
- `custom_credit_cost`: Custom pricing override for billing during migration
"""
try:
result = await llm_db.toggle_model(
model_id=model_id,
is_enabled=request.is_enabled,
migrate_to_slug=request.migrate_to_slug,
migration_reason=request.migration_reason,
custom_credit_cost=request.custom_credit_cost,
)
await _refresh_runtime_state()
if result.nodes_migrated > 0:
logger.info(
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
result.model.slug,
"enabled" if request.is_enabled else "disabled",
result.nodes_migrated,
result.migrated_to_slug,
result.migration_id,
)
return result
except ValueError as exc:
logger.warning("Model toggle validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to toggle model availability",
) from exc
@router.get(
"/models/{model_id}/usage",
summary="Get model usage count",
response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
"""Get the number of workflow nodes using this model."""
try:
return await llm_db.get_model_usage(model_id=model_id)
except ValueError as exc:
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to get model usage %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get model usage",
) from exc
@router.delete(
"/models/{model_id}",
summary="Delete LLM model and migrate workflows",
response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
model_id: str,
replacement_model_slug: str | None = fastapi.Query(
default=None,
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
),
):
"""
Delete a model and optionally migrate workflows using it to a replacement model.
If no workflows are using this model, it can be deleted without providing a
replacement. If workflows exist, replacement_model_slug is required.
This endpoint:
1. Counts how many workflow nodes use the model being deleted
2. If nodes exist, validates the replacement model and migrates them
3. Deletes the model record
4. Refreshes all caches and notifies executors
Example: DELETE /api/llm/admin/models/{id}?replacement_model_slug=gpt-4o
Example (no usage): DELETE /api/llm/admin/models/{id}
"""
try:
result = await llm_db.delete_model(
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug,
)
return result
except ValueError as exc:
# Validation errors (model not found, replacement invalid, etc.)
logger.warning("Model deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
# ============================================================================
# Migration Management Endpoints
# ============================================================================
@router.get(
"/migrations",
summary="List model migrations",
response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
include_reverted: bool = fastapi.Query(
default=False, description="Include reverted migrations in the list"
),
):
"""
List all model migrations.
Migrations are created when disabling a model with the migrate_to_slug option.
They can be reverted to restore the original model configuration.
"""
try:
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
return llm_model.LlmMigrationsResponse(migrations=migrations)
except Exception as exc:
logger.exception("Failed to list migrations: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list migrations",
) from exc
@router.get(
"/migrations/{migration_id}",
summary="Get migration details",
response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
"""Get details of a specific migration."""
try:
migration = await llm_db.get_migration(migration_id)
if not migration:
raise fastapi.HTTPException(
status_code=404, detail=f"Migration '{migration_id}' not found"
)
return migration
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get migration",
) from exc
@router.post(
"/migrations/{migration_id}/revert",
summary="Revert a model migration",
response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
migration_id: str,
request: llm_model.RevertMigrationRequest | None = None,
):
"""
Revert a model migration, restoring affected workflows to their original model.
This only reverts the specific nodes that were part of the migration.
The source model must exist for the revert to succeed.
Options:
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
Response includes:
- `nodes_reverted`: Number of nodes successfully reverted
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
- `source_model_re_enabled`: Whether the source model was re-enabled
Requirements:
- Migration must not already be reverted
- Source model must exist
"""
try:
re_enable = request.re_enable_source_model if request else True
result = await llm_db.revert_migration(
migration_id,
re_enable_source_model=re_enable,
)
await _refresh_runtime_state()
logger.info(
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
"(%d already changed, source re-enabled=%s)",
migration_id,
result.nodes_reverted,
result.target_model_slug,
result.source_model_slug,
result.nodes_already_changed,
result.source_model_re_enabled,
)
return result
except ValueError as exc:
logger.warning("Migration revert validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to revert migration",
) from exc
# ============================================================================
# Creator Management Endpoints
# ============================================================================
@router.get(
"/creators",
summary="List model creators",
response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
"""
List all model creators.
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
This is distinct from providers who host/serve the models (e.g., OpenRouter).
"""
try:
creators = await llm_db.list_creators()
return llm_model.LlmCreatorsResponse(creators=creators)
except Exception as exc:
logger.exception("Failed to list creators: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list creators",
) from exc
@router.get(
"/creators/{creator_id}",
summary="Get creator details",
response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
"""Get details of a specific model creator."""
try:
creator = await llm_db.get_creator(creator_id)
if not creator:
raise fastapi.HTTPException(
status_code=404, detail=f"Creator '{creator_id}' not found"
)
return creator
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get creator",
) from exc
@router.post(
"/creators",
summary="Create model creator",
response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
"""
Create a new model creator.
A creator represents an organization that creates/trains AI models,
such as OpenAI, Anthropic, Meta, or Google.
"""
try:
creator = await llm_db.upsert_creator(request=request)
await _refresh_runtime_state()
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
return creator
except Exception as exc:
logger.exception("Failed to create creator: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to create creator",
) from exc
@router.patch(
"/creators/{creator_id}",
summary="Update model creator",
response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
creator_id: str,
request: llm_model.UpsertLlmCreatorRequest,
):
"""Update an existing model creator."""
try:
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
await _refresh_runtime_state()
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
return creator
except Exception as exc:
logger.exception("Failed to update creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to update creator",
) from exc
@router.delete(
"/creators/{creator_id}",
summary="Delete model creator",
response_model=dict,
)
async def delete_llm_creator(creator_id: str):
"""
Delete a model creator.
This will remove the creator association from all models that reference it
(sets creatorId to NULL), but will not delete the models themselves.
"""
try:
await llm_db.delete_creator(creator_id)
await _refresh_runtime_state()
logger.info("Deleted model creator '%s'", creator_id)
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
except ValueError as exc:
logger.warning("Creator deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete creator",
) from exc
# ============================================================================
# Recommended Model Endpoints
# ============================================================================
@router.get(
"/recommended-model",
summary="Get recommended model",
response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
"""
Get the currently recommended LLM model.
The recommended model is shown to users as the default/suggested option
in model selection dropdowns.
"""
try:
model = await llm_db.get_recommended_model()
return llm_model.RecommendedModelResponse(
model=model,
slug=model.slug if model else None,
)
except Exception as exc:
logger.exception("Failed to get recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get recommended model",
) from exc
@router.post(
"/recommended-model",
summary="Set recommended model",
response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
"""
Set a model as the recommended model.
This clears the recommended flag from any other model and sets it on
the specified model. The model must be enabled to be set as recommended.
The recommended model is displayed to users as the default/suggested
option in model selection dropdowns throughout the platform.
"""
try:
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
await _refresh_runtime_state()
logger.info(
"Set recommended model to '%s' (previous: %s)",
model.slug,
previous_slug or "none",
)
return llm_model.SetRecommendedModelResponse(
model=model,
previous_recommended_slug=previous_slug,
message=f"Model '{model.display_name}' is now the recommended model",
)
except ValueError as exc:
logger.warning("Set recommended model validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to set recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to set recommended model",
) from exc
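
For reference, a minimal sketch of exercising the two recommended-model endpoints with FastAPI's TestClient, mirroring the test module shown further down this diff. The "/admin/llm" prefix is taken from that module; the model_id field name follows the route's SetRecommendedModelRequest usage, and the admin-auth dependency overrides are omitted here for brevity, so treat this as illustrative rather than a copy of the real tests.

import fastapi
import fastapi.testclient
import backend.api.features.admin.llm_routes as llm_routes

app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")
client = fastapi.testclient.TestClient(app)

# Read the current recommendation; "model" is null when nothing is recommended yet.
current = client.get("/admin/llm/recommended-model").json()

# Promote a model; the route clears the flag on any previously recommended model.
# Expect 200 with the updated model on success, or 400 if the model is disabled or unknown.
result = client.post("/admin/llm/recommended-model", json={"model_id": "model-1"})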


@@ -1,491 +0,0 @@
import json
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
import backend.api.features.admin.llm_routes as llm_routes
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_list_llm_providers_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM providers"""
# Mock the database function
mock_providers = [
{
"id": "provider-1",
"name": "openai",
"display_name": "OpenAI",
"description": "OpenAI LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
{
"id": "provider-2",
"name": "anthropic",
"display_name": "Anthropic",
"description": "Anthropic LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
response = client.get("/admin/llm/providers")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["providers"]) == 2
assert response_data["providers"][0]["name"] == "openai"
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_providers_success.json",
)
def test_list_llm_models_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM models with pagination"""
# Mock the database function - now returns LlmModelsResponse
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=True,
capabilities={},
metadata={},
costs=[
llm_model.LlmModelCost(
id="cost-1",
credit_cost=10,
credential_provider="openai",
metadata={},
)
],
)
mock_response = llm_model.LlmModelsResponse(
models=[mock_model],
pagination=Pagination(
total_items=1,
total_pages=1,
current_page=1,
page_size=50,
),
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_response),
)
response = client.get("/admin/llm/models")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["models"]) == 1
assert response_data["models"][0]["slug"] == "gpt-4o"
assert response_data["pagination"]["total_items"] == 1
assert response_data["pagination"]["page_size"] == 50
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_models_success.json",
)
def test_create_llm_provider_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM provider"""
mock_provider = {
"id": "new-provider-id",
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
response = client.post("/admin/llm/providers", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_provider_success.json",
)
def test_create_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM model"""
mock_model = {
"id": "new-model-id",
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-id",
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
response = client.post("/admin/llm/models", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_model_success.json",
)
def test_update_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful update of LLM model"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o Updated",
"description": "Updated description",
"provider_id": "provider-1",
"context_window": 256000,
"max_output_tokens": 32768,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 15,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"display_name": "GPT-4o Updated",
"description": "Updated description",
"context_window": 256000,
"max_output_tokens": 32768,
}
response = client.patch("/admin/llm/models/model-1", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"update_llm_model_success.json",
)
def test_toggle_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful toggling of LLM model enabled status"""
# Create a proper mock model object
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=False,
capabilities={},
metadata={},
costs=[],
)
# Create a proper ToggleLlmModelResponse
mock_response = llm_model.ToggleLlmModelResponse(
model=mock_model,
nodes_migrated=0,
migrated_to_slug=None,
migration_id=None,
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {"is_enabled": False}
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["model"]["is_enabled"] is False
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"toggle_llm_model_success.json",
)
def test_delete_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful deletion of LLM model with migration"""
# Create a proper DeleteLlmModelResponse
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="gpt-3.5-turbo",
deleted_model_display_name="GPT-3.5 Turbo",
replacement_model_slug="gpt-4o-mini",
nodes_migrated=42,
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
)
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
assert response_data["nodes_migrated"] == 42
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"delete_llm_model_success.json",
)
def test_delete_llm_model_validation_error(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails with proper error when validation fails"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]
def test_delete_llm_model_no_replacement_with_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails when nodes exist but no replacement is provided"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(
side_effect=ValueError(
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
"Please provide a replacement_model_slug to migrate them."
)
),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 400
assert "workflow node(s) are using it" in response.json()["detail"]
def test_delete_llm_model_no_replacement_no_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="unused-model",
deleted_model_display_name="Unused Model",
replacement_model_slug=None,
nodes_migrated=0,
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "unused-model"
assert response_data["nodes_migrated"] == 0
assert response_data["replacement_model_slug"] is None
mock_refresh.assert_called_once()


@@ -20,7 +20,6 @@ from backend.blocks._base import (
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.data.llm_registry import get_all_model_slugs_for_validation
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -37,14 +36,7 @@ from .model import (
)
logger = logging.getLogger(__name__)
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -509,8 +501,8 @@ async def _get_static_counts():
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models from registry
if any(query in name for name in _get_llm_models()):
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False


@@ -27,12 +27,11 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
)
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -93,6 +92,31 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
@@ -138,6 +162,17 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",


@@ -334,9 +334,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
)
return session
except Exception as e:
@@ -378,11 +377,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
)
return ChatSession.from_db(prisma_session, messages)
@@ -433,10 +430,9 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -476,7 +472,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.info(f"Session {session_id} not in cache, checking database")
logger.debug(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -493,7 +489,6 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -558,6 +553,40 @@ async def upsert_chat_session(
return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -664,13 +693,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Invalidate cache so next fetch gets updated title
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
except Exception as e:
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
return True
except Exception as e:


@@ -1,5 +1,6 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -11,13 +12,22 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -41,6 +51,7 @@ from .tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
config = ChatConfig()
@@ -232,6 +243,10 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -301,10 +316,9 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -313,6 +327,25 @@ async def stream_chat_post(
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
@@ -328,7 +361,7 @@ async def stream_chat_post(
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -349,15 +382,47 @@ async def stream_chat_post(
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
async for chunk in chat_service.stream_chat_completion(
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id,
request.message,
None, # Message already in session
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
session=session, # Pass session with message already added
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
@@ -378,7 +443,7 @@ async def stream_chat_post(
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
@@ -405,6 +470,17 @@ async def stream_chat_post(
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
@@ -507,8 +583,14 @@ async def stream_chat_post(
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -752,8 +834,6 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:


@@ -0,0 +1,14 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]
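
A rough consumption sketch for the exported entry point. The argument list mirrors the call site in the chat routes diff above, and chunks are assumed to expose to_sse() as they do there; the exact signature is inferred from that diff, not authoritative.

from backend.api.features.chat.sdk import stream_chat_completion_sdk

async def relay_sdk_stream(session_id, user_id, session, context, task_id):
    # The user message is already appended to the session by the route,
    # so None is passed for the message argument, matching the routes code.
    async for chunk in stream_chat_completion_sdk(
        session_id,
        None,
        is_user_message=True,
        user_id=user_id,
        session=session,
        context=context,
        _task_id=task_id,
    ):
        yield chunk.to_sse()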


@@ -0,0 +1,203 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
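
A minimal driving loop for the adapter, assuming some async source of SDK Message objects (producing that stream is outside this file); the adapter API used below (constructor, set_task_id, convert_message, and to_sse on the emitted events) is taken from this diff.

async def adapt_sdk_stream(sdk_messages, task_id: str):
    adapter = SDKResponseAdapter()
    adapter.set_task_id(task_id)
    async for sdk_message in sdk_messages:
        # One SDK message can expand into several Vercel AI SDK events
        # (step, text, and tool lifecycle), already ordered by the adapter.
        for event in adapter.convert_message(sdk_message):
            yield event.to_sse()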


@@ -0,0 +1,366 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]


@@ -0,0 +1,335 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access).
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
BLOCKED_TOOLS = {
"Bash",
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
f"directory.{workspace_hint} "
"This is enforced by the platform and cannot be bypassed."
)
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
"Use the CoPilot-specific MCP tools instead."
)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny(
"[SECURITY] Input contains a blocked pattern. "
"This is enforced by the platform and cannot be bypassed."
)
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
task_spawn_count += 1
if task_spawn_count > max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
),
)
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}
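
A wiring sketch for these hooks, assuming ClaudeAgentOptions accepts the hooks mapping as the docstring above indicates; the on_transcript callback and the workspace path are illustrative placeholders, not code from this PR.

from claude_agent_sdk import ClaudeAgentOptions

from backend.api.features.chat.sdk.security_hooks import create_security_hooks

def build_sdk_options(user_id: str, sdk_cwd: str) -> ClaudeAgentOptions:
    def on_transcript(transcript_path: str, sdk_session_id: str) -> None:
        # Hypothetical callback: persist the JSONL transcript for later --resume.
        ...

    hooks = create_security_hooks(
        user_id=user_id,
        sdk_cwd=sdk_cwd,  # e.g. /tmp/copilot-<session-id>
        max_subtasks=10,
        on_stop=on_transcript,
    )
    return ClaudeAgentOptions(cwd=sdk_cwd, hooks=hooks)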


@@ -0,0 +1,165 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
def test_bash_builtin_always_blocked():
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}
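For reference, a minimal sketch of the decision shape these tests assert on. Only the permissionDecision field is actually checked by _is_denied, so the other keys shown here are assumptions about the hook payload, not part of the commit:

denied_example = {
    "hookSpecificOutput": {
        "hookEventName": "PreToolUse",  # assumed; not asserted by these tests
        "permissionDecision": "deny",
        "permissionDecisionReason": "path outside workspace",  # assumed
    }
}
allowed_example: dict = {}  # an empty dict means no objection; the tool call proceeds

assert _is_denied(denied_example) and not _is_denied(allowed_example)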

View File

@@ -0,0 +1,751 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
)
from .transcript import (
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """
## Tool notes
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
same working directory. Files created by one are readable by the other.
These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
"""
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
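# Illustration only (values assumed, not taken from this diff): with
# config.base_url = "https://openrouter.ai/api/v1" and config.api_key set,
# _build_sdk_env() returns:
_example_sdk_env = {
    "ANTHROPIC_BASE_URL": "https://openrouter.ai/api",  # trailing "/v1" stripped
    "ANTHROPIC_AUTH_TOKEN": "<config.api_key>",
    "ANTHROPIC_API_KEY": "",  # explicitly empty so the CLI uses AUTH_TOKEN
}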
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
cwd = make_session_path(session_id)
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
cwd = os.path.normpath(cwd)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
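# Illustration only (session ID assumed): for sdk_cwd = "/tmp/copilot-abc123",
# the cleanup above targets ~/.claude/projects/-tmp-copilot-abc123/tool-results/
# because the SDK encodes the cwd by replacing path separators with dashes:
assert "/tmp/copilot-abc123".replace("/", "-") == "-tmp-copilot-abc123"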
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Initialise sdk_cwd before the try so the finally can reference it
# even if _make_sdk_cwd raises (in that case it stays as "").
sdk_cwd = ""
use_resume = False
try:
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env()
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
captured_transcript = CapturedTranscript()
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
)
# --- Resume strategy: download transcript from bucket ---
resume_file: str | None = None
use_resume = False
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
transcript_content = await download_transcript(user_id, session_id)
if transcript_content and validate_transcript(transcript_content):
resume_file = write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
logger.info(
f"[SDK] Using --resume with transcript "
f"({len(transcript_content)} bytes)"
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": COPILOT_TOOL_NAMES,
"disallowed_tools": ["Bash"],
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_env:
sdk_options_kwargs["model"] = sdk_model
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query: with --resume the CLI already has full
# context, so we only send the new message. Without
# resume, compress history into a context prefix.
query_message = current_message
if not use_resume and len(session.messages) > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({len(session.messages)} messages) — "
f"no transcript available for --resume"
)
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
)
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Capture transcript while CLI is still alive ---
# Must happen INSIDE async with: close() sends SIGTERM
# which kills the CLI before it can flush the JSONL.
if (
config.claude_agent_use_resume
and user_id
and captured_transcript.available
):
# Give CLI time to flush JSONL writes before we read
await asyncio.sleep(0.5)
raw_transcript = read_transcript_file(captured_transcript.path)
if raw_transcript:
task = asyncio.create_task(
_upload_transcript_bg(user_id, session_id, raw_transcript)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
logger.debug("[SDK] Stop hook fired but transcript not usable")
except ImportError:
raise RuntimeError(
"claude-agent-sdk is not installed. "
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
"to use the OpenAI-compatible fallback."
)
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
async def _upload_transcript_bg(
user_id: str, session_id: str, raw_content: str
) -> None:
"""Background task to strip progress entries and upload transcript."""
try:
await upload_transcript(user_id, session_id, raw_content)
except Exception as e:
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -0,0 +1,325 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
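# Illustration only: the stash is per-request ContextVar state. A typical flow,
# using the real functions from this module with assumed values, looks like:
#   set_execution_context(user_id, session)        # initialises an empty stash
#   ...the MCP handler runs "find_agent" and stashes its full JSON output...
#   pop_pending_tool_output("find_agent")          # -> the untruncated text
#   pop_pending_tool_output("find_agent")          # -> None (already consumed)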
async def _execute_tool_sync(
base_tool: BaseTool,
user_id: str | None,
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
text = (
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
],
"isError": True,
}
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session = get_execution_context()
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
lines = f.readlines()
selected = lines[offset : offset + limit]
content = "".join(selected)
# Clean up to prevent accumulation in long-running pods
try:
os.remove(real_path)
except OSError:
pass
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
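For reference, a sketch of how the allowed-tool names compose; the registry entries named below come from the tools __init__ diff later in this page, everything else is assumed:

# With TOOL_REGISTRY containing e.g. "find_agent" and "bash_exec",
# COPILOT_TOOL_NAMES includes:
#   "mcp__copilot__find_agent", "mcp__copilot__bash_exec",  # wrapped CoPilot tools
#   "mcp__copilot__Read",                                   # oversized-result reader
#   "Read", "Write", "Edit", "Glob", "Grep", "Task"         # hook-guarded SDK built-ins
assert f"{MCP_TOOL_PREFIX}bash_exec" == "mcp__copilot__bash_exec"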

View File

@@ -0,0 +1,355 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
# Entry types that can be safely removed from the transcript without breaking
# the parentUuid conversation tree that ``--resume`` relies on.
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
# - file-history-snapshot: undo tracking metadata
# - queue-operation: internal queue bookkeeping
# - summary: session summaries
# - pr-link: PR link metadata
STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# ---------------------------------------------------------------------------
# Progress stripping
# ---------------------------------------------------------------------------
def strip_progress_entries(content: str) -> str:
"""Remove progress/metadata entries from a JSONL transcript.
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
"""
lines = content.strip().split("\n")
entries: list[dict] = []
for line in lines:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for entry in entries:
if "_raw" in entry:
kept.append(entry)
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
result_lines: list[str] = []
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug(f"[Transcript] Empty file: {transcript_path}")
return None
lines = content.strip().split("\n")
if len(lines) < 2:
# Metadata-only files have 1 line (single queue-operation or snapshot).
logger.debug(
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
)
return None
# Quick structural validation — parse first and last lines.
json.loads(lines[0])
json.loads(lines[-1])
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
everything else and truncate to *max_len* so the result cannot introduce
path separators or other special characters.
"""
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
return cleaned or "unknown"
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
cwd: str,
) -> str | None:
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
The file lives in the session working directory so it is cleaned up
automatically when the session ends.
Returns the absolute path to the file, or ``None`` on failure.
"""
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
os.makedirs(real_cwd, exist_ok=True)
safe_id = _sanitize_id(session_id, max_len=8)
jsonl_path = os.path.realpath(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
return False
return has_user and has_assistant
# ---------------------------------------------------------------------------
# Bucket storage (GCS / local via WorkspaceStorageBackend)
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
``local://workspace_id/file_id/filename``. Since we use deterministic
arguments we can reconstruct the same path for download/delete without
having stored the return value.
"""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = _storage_path_parts(user_id, session_id)
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
return f"gcs://{backend.bucket_name}/{blob}"
else:
# LocalWorkspaceStorage returns local://{relative_path}
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
"""Strip progress entries and upload transcript to bucket storage.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
"""
from backend.util.workspace_storage import get_workspace_storage
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
logger.warning(
f"[Transcript] Skipping upload — stripped content is not a valid "
f"transcript for session {session_id}"
)
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing transcript "
f"({len(existing)}B) >= new ({new_size}B) for session "
f"{session_id}"
)
return
except Exception:
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
logger.info(
f"[Transcript] Uploaded {new_size} bytes "
f"(stripped from {len(content)}) for session {session_id}"
)
async def download_transcript(user_id: str, session_id: str) -> str | None:
"""Download transcript from bucket storage.
Returns the JSONL content string, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
data = await storage.retrieve(path)
content = data.decode("utf-8")
logger.info(
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
)
return content
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
return None
except Exception as e:
logger.warning(f"[Transcript] Failed to download transcript: {e}")
return None
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding
If "default" and this is the user's first session, will use "onboarding" instead.
user_id: The user ID for fetching business understanding.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
Returns:
Tuple of (compiled prompt string, business understanding object)
@@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -374,7 +380,6 @@ async def stream_chat_completion(
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
@@ -459,8 +464,9 @@ async def stream_chat_completion(
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
user_messages = [m for m in session.messages if m.role == "user"]
first_user_msg = message or (user_messages[0].content if user_messages else None)
if is_user_message and first_user_msg and not session.title:
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
@@ -468,7 +474,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session
captured_session_id = session_id
captured_message = message
captured_message = first_user_msg
captured_user_id = user_id
async def _update_title():
@@ -1237,7 +1243,7 @@ async def _stream_chat_chunks(
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from os import getenv
@@ -11,6 +12,8 @@ from .response_model import (
StreamTextDelta,
StreamToolOutputAvailable,
)
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@@ -80,3 +83,96 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.
Turn 1: Send a message containing a unique keyword.
Turn 2: Ask the model to recall that keyword — proving the transcript was
persisted and restored via --resume.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
from .config import ChatConfig
cfg = ChatConfig()
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
# Wait for background upload task to complete (retry up to 5s)
transcript = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
break
assert transcript, (
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -814,6 +814,28 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"

View File

@@ -9,6 +9,8 @@ from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -19,6 +21,7 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
@@ -43,9 +46,14 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),

View File

@@ -0,0 +1,131 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
"""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
BashExecResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
get_workspace_dir,
has_full_sandbox,
run_sandboxed,
)
logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands in a bubblewrap sandbox."""
@property
def name(self) -> str:
return "bash_exec"
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script in a bubblewrap sandbox. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Bash command or script to execute.",
},
"timeout": {
"type": "integer",
"description": (
"Max execution time in seconds (default 30, max 120)."
),
"default": 30,
},
},
"required": ["command"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = kwargs.get("timeout", 30)
if not command:
return ErrorResponse(
message="No command provided.",
error="empty_command",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
command=["bash", "-c", command],
cwd=workspace,
timeout=timeout,
)
return BashExecResponse(
message=(
"Execution timed out"
if timed_out
else f"Command executed (exit {exit_code})"
),
stdout=stdout,
stderr=stderr,
exit_code=exit_code,
timed_out=timed_out,
session_id=session_id,
)
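A minimal invocation sketch, assuming a ChatSession instance and a Linux host with bubblewrap available. In practice the MCP tool adapter calls BaseTool.execute() the same way; the result wrapper is assumed to expose success and output, as the adapter relies on:

async def run_example(session: ChatSession) -> None:
    tool = BashExecTool()
    # kwargs mirror the JSON schema declared in parameters above
    result = await tool.execute(
        user_id=None,
        session=session,
        tool_call_id="example-call",
        command="ls -la && wc -l *.txt",
        timeout=30,
    )
    print(result.success, result.output)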

View File

@@ -0,0 +1,127 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ResponseType,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.api.features.chat import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)
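
A short sketch of the three possible outcomes (illustrative; "op-1" is a hypothetical id):

async def demo(session):
    tool = CheckOperationStatusTool()
    # No arguments at all -> ErrorResponse(error="missing_parameter")
    resp = await tool._execute(user_id=None, session=session)
    # Known id -> status comes from the stream-registry task ("running" / "completed" / "failed")
    resp = await tool._execute(user_id=None, session=session, operation_id="op-1")
    # Unknown or expired id -> ErrorResponse(error="not_found")
    return resp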

View File

@@ -146,6 +146,7 @@ class FindBlockTool(BaseTool):
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
)

View File

@@ -41,6 +41,12 @@ class ResponseType(str, Enum):
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Base response model
@@ -335,6 +341,19 @@ class BlockInfoSummary(BaseModel):
id: str
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
)
class BlockListResponse(ToolResponseBase):
@@ -344,6 +363,10 @@ class BlockListResponse(ToolResponseBase):
blocks: list[BlockInfoSummary]
count: int
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
)
class BlockDetails(BaseModel):
@@ -430,3 +453,24 @@ class AsyncProcessingResponse(ToolResponseBase):
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
type: ResponseType = ResponseType.WEB_FETCH
url: str
status_code: int
content_type: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
type: ResponseType = ResponseType.BASH_EXEC
stdout: str
stderr: str
exit_code: int
timed_out: bool = False

View File

@@ -0,0 +1,265 @@
"""Sandbox execution utilities for code execution tools.
Provides filesystem + network isolated command execution using **bubblewrap**
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
writable workspace only, clean environment, network blocked.
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
and refuse to run if bubblewrap is not available.
"""
import asyncio
import logging
import os
import platform
import shutil
logger = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = 30
_MAX_TIMEOUT = 120
# ---------------------------------------------------------------------------
# Sandbox capability detection (cached at first call)
# ---------------------------------------------------------------------------
_BWRAP_AVAILABLE: bool | None = None
def has_full_sandbox() -> bool:
"""Return True if bubblewrap is available (filesystem + network isolation).
On non-Linux platforms (macOS), always returns False.
"""
global _BWRAP_AVAILABLE
if _BWRAP_AVAILABLE is None:
_BWRAP_AVAILABLE = (
platform.system() == "Linux" and shutil.which("bwrap") is not None
)
return _BWRAP_AVAILABLE
WORKSPACE_PREFIX = "/tmp/copilot-"
def make_session_path(session_id: str) -> str:
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
Shared by both the SDK working-directory setup and the sandbox tools so
they always resolve to the same directory for a given session.
Steps:
1. Strip all characters except ``[A-Za-z0-9-]``.
2. Construct ``/tmp/copilot-<safe_id>``.
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
sanitizer) to prevent path traversal.
Raises:
ValueError: If the resulting path escapes the prefix.
"""
import re
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
safe_id = "default"
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
if not path.startswith(WORKSPACE_PREFIX):
raise ValueError(f"Session path escaped prefix: {path}")
return path
def get_workspace_dir(session_id: str) -> str:
"""Get or create the workspace directory for a session.
Uses :func:`make_session_path` — the same path the SDK uses — so that
bash_exec shares the workspace with the SDK file tools.
"""
workspace = make_session_path(session_id)
os.makedirs(workspace, exist_ok=True)
return workspace
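# Illustrative sanitization behaviour (example values, not from the diff):
#   make_session_path("3fA9-b12c")        == "/tmp/copilot-3fA9-b12c"
#   make_session_path("../../etc/passwd") == "/tmp/copilot-etcpasswd"   # traversal chars stripped
#   make_session_path("!!!")              == "/tmp/copilot-default"     # empty after stripping -> "default"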
# ---------------------------------------------------------------------------
# Bubblewrap command builder
# ---------------------------------------------------------------------------
# System directories mounted read-only inside the sandbox.
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
_SYSTEM_RO_BINDS = [
"/usr", # binaries, libraries, Python interpreter
"/etc", # system config: ld.so, locale, passwd, alternatives
]
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
# can't create a symlink target, so we detect and use --symlink instead.
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
_COMPAT_PATHS = [
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
]
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
# Applied via ulimit inside the sandbox before exec'ing the user command.
_RESOURCE_LIMITS = (
"ulimit -u 64" # max 64 processes (prevents fork bombs)
" -v 524288" # 512 MB virtual memory
" -f 51200" # 50 MB max file size (1024-byte blocks)
" -n 256" # 256 open file descriptors
" 2>/dev/null"
)
def _build_bwrap_command(
command: list[str], cwd: str, env: dict[str, str]
) -> list[str]:
"""Build a bubblewrap command with strict filesystem + network isolation.
Security model:
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
- **Writable workspace only**: the per-session workspace is the sole
writable path.
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
Only the explicitly-passed safe env vars are set inside the sandbox.
- **Network isolation**: ``--unshare-net`` blocks all network access.
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
- **New session**: prevents terminal control escape.
- **Die with parent**: prevents orphaned sandbox processes.
"""
cmd = [
"bwrap",
# Create a new user namespace so bwrap can set up sandboxing
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
"--unshare-user",
# Wipe all inherited environment variables (API keys, secrets, etc.)
"--clearenv",
]
# Set only the safe env vars inside the sandbox
for key, value in env.items():
cmd.extend(["--setenv", key, value])
# System directories: read-only
for path in _SYSTEM_RO_BINDS:
cmd.extend(["--ro-bind", path, path])
# Compat paths: use --symlink when host path is a symlink (Debian 13),
# --ro-bind when it's a real directory (older distros).
for path, symlink_target in _COMPAT_PATHS:
if os.path.islink(path):
cmd.extend(["--symlink", symlink_target, path])
elif os.path.exists(path):
cmd.extend(["--ro-bind", path, path])
# Wrap the user command with resource limits:
# sh -c 'ulimit ...; exec "$@"' -- <original command>
# `exec "$@"` replaces the shell so there's no extra process overhead,
# and properly handles arguments with spaces.
limited_command = [
"sh",
"-c",
f'{_RESOURCE_LIMITS}; exec "$@"',
"--",
*command,
]
cmd.extend(
[
# Fresh virtual filesystems
"--dev",
"/dev",
"--proc",
"/proc",
"--tmpfs",
"/tmp",
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
# (workspace lives under /tmp/copilot-<session>)
"--bind",
cwd,
cwd,
# Isolation
"--unshare-net",
"--die-with-parent",
"--new-session",
"--chdir",
cwd,
"--",
*limited_command,
]
)
return cmd
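# Illustrative only: the argv produced for a toy command, e.g.
#   _build_bwrap_command(["echo", "hi"], "/tmp/copilot-demo", {"PATH": "/usr/local/bin:/usr/bin:/bin"})
# begins ["bwrap", "--unshare-user", "--clearenv", "--setenv", "PATH", ...], continues with the
# ro-binds / compat symlinks, the fresh /dev, /proc and /tmp, the workspace bind and the
# isolation flags, and ends with
#   ["--chdir", "/tmp/copilot-demo", "--",
#    "sh", "-c", 'ulimit -u 64 -v 524288 -f 51200 -n 256 2>/dev/null; exec "$@"', "--", "echo", "hi"]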
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def run_sandboxed(
command: list[str],
cwd: str,
timeout: int = _DEFAULT_TIMEOUT,
env: dict[str, str] | None = None,
) -> tuple[str, str, int, bool]:
"""Run a command inside a bubblewrap sandbox.
Callers **must** check :func:`has_full_sandbox` before calling this
function. If bubblewrap is not available, this function raises
:class:`RuntimeError` rather than running unsandboxed.
Returns:
(stdout, stderr, exit_code, timed_out)
"""
if not has_full_sandbox():
raise RuntimeError(
"run_sandboxed() requires bubblewrap but bwrap is not available. "
"Callers must check has_full_sandbox() before calling this function."
)
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
safe_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin",
"HOME": cwd,
"TMPDIR": cwd,
"LANG": "en_US.UTF-8",
"PYTHONDONTWRITEBYTECODE": "1",
"PYTHONIOENCODING": "utf-8",
}
if env:
safe_env.update(env)
full_command = _build_bwrap_command(command, cwd, safe_env)
try:
proc = await asyncio.create_subprocess_exec(
*full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=safe_env,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout = stdout_bytes.decode("utf-8", errors="replace")
stderr = stderr_bytes.decode("utf-8", errors="replace")
return stdout, stderr, proc.returncode or 0, False
except asyncio.TimeoutError:
proc.kill()
await proc.communicate()
return "", f"Execution timed out after {timeout}s", -1, True
except RuntimeError:
raise
except Exception as e:
return "", f"Sandbox error: {e}", -1, False

View File

@@ -0,0 +1,151 @@
"""Web fetch tool — safely retrieve public web page content."""
import logging
from typing import Any
import aiohttp
import html2text
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ToolResponseBase,
WebFetchResponse,
)
from backend.util.request import Requests
logger = logging.getLogger(__name__)
# Limits
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
# Content types we'll read as text
_TEXT_CONTENT_TYPES = {
"text/html",
"text/plain",
"text/xml",
"text/csv",
"text/markdown",
"application/json",
"application/xml",
"application/xhtml+xml",
"application/rss+xml",
"application/atom+xml",
}
def _is_text_content(content_type: str) -> bool:
base = content_type.split(";")[0].strip().lower()
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
def _html_to_text(html: str) -> str:
h = html2text.HTML2Text()
h.ignore_links = False
h.ignore_images = True
h.body_width = 0
return h.handle(html)
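# Illustrative helper behaviour (example values, not from the diff):
#   _is_text_content("text/html; charset=utf-8") -> True
#   _is_text_content("application/json")         -> True
#   _is_text_content("image/png")                -> False
#   _html_to_text("<h1>Hi</h1><p>Docs</p>")      -> roughly "# Hi\n\nDocs\n\n" (links kept, images dropped)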
class WebFetchTool(BaseTool):
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
@property
def name(self) -> str:
return "web_fetch"
@property
def description(self) -> str:
return (
"Fetch the content of a public web page by URL. "
"Returns readable text extracted from HTML by default. "
"Useful for reading documentation, articles, and API responses. "
"Only supports HTTP/HTTPS GET requests to public URLs "
"(private/internal network addresses are blocked)."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The public HTTP/HTTPS URL to fetch.",
},
"extract_text": {
"type": "boolean",
"description": (
"If true (default), extract readable text from HTML. "
"If false, return raw content."
),
"default": True,
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
url: str = (kwargs.get("url") or "").strip()
extract_text: bool = kwargs.get("extract_text", True)
session_id = session.session_id if session else None
if not url:
return ErrorResponse(
message="Please provide a URL to fetch.",
error="missing_url",
session_id=session_id,
)
try:
client = Requests(raise_for_status=False, retry_max_attempts=1)
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
except ValueError as e:
# validate_url raises ValueError for SSRF / blocked IPs
return ErrorResponse(
message=f"URL blocked: {e}",
error="url_blocked",
session_id=session_id,
)
except Exception as e:
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
return ErrorResponse(
message=f"Failed to fetch URL: {e}",
error="fetch_failed",
session_id=session_id,
)
content_type = response.headers.get("content-type", "")
if not _is_text_content(content_type):
return ErrorResponse(
message=f"Non-text content type: {content_type.split(';')[0]}",
error="unsupported_content_type",
session_id=session_id,
)
raw = response.content[:_MAX_CONTENT_BYTES]
text = raw.decode("utf-8", errors="replace")
if extract_text and "html" in content_type.lower():
text = _html_to_text(text)
return WebFetchResponse(
message=f"Fetched {url}",
url=response.url,
status_code=response.status,
content_type=content_type.split(";")[0].strip(),
content=text,
truncated=len(response.content) > _MAX_CONTENT_BYTES,
session_id=session_id,
)
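
A minimal usage sketch (the URL is illustrative; SSRF protection comes from backend.util.request.Requests):

async def demo(session):
    tool = WebFetchTool()
    resp = await tool._execute(
        user_id=None,
        session=session,
        url="https://example.com/",
        extract_text=True,
    )
    # Success: WebFetchResponse with status_code, content_type, and markdown-style content.
    # Private/internal addresses make the client raise ValueError -> ErrorResponse(error="url_blocked").
    return resp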

View File

@@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool):
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
@@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
@@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Delete a file from the user's persistent workspace (cloud storage). "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."

View File

@@ -393,7 +393,6 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)

View File

@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -39,15 +38,13 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -118,27 +115,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.blocks._base import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# migrate_llm_models uses registry default model
from backend.blocks.llm import LlmModel
default_model_slug = llm_registry.get_default_model_slug()
if default_model_slug:
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
else:
logger.warning("Skipping LLM model migration: no default model available")
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
@@ -340,16 +321,6 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
app.include_router(
backend.api.features.admin.llm_routes.router,
tags=["v2", "admin", "llm"],
prefix="/api/llm/admin",
)
app.include_router(
public_llm_routes.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)

View File

@@ -79,39 +79,7 @@ async def event_broadcaster(manager: ConnectionManager):
payload=notification.payload,
)
async def registry_refresh_worker():
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in pubsub.listen():
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={
"type": "LLM_REGISTRY_REFRESH",
"event": "registry_updated",
},
)
await asyncio.gather(
execution_worker(),
notification_worker(),
registry_refresh_worker(),
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()

View File

@@ -133,26 +133,7 @@ class BlockInfo(BaseModel):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
@classmethod
def clear_schema_cache(cls) -> None:
"""Clear the cached JSON schema for this class."""
# Use None instead of {} because {} is truthy and would prevent regeneration
cls.cached_jsonschema = None # type: ignore
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:

View File

@@ -7,6 +7,7 @@ from backend.blocks._base import (
BlockSchemaOutput,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -15,7 +16,6 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
llm_model_schema_extra,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for evaluating the condition.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": LlmModel.default(),
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -4,17 +4,16 @@ import logging
import re
import secrets
from abc import ABC
from enum import Enum
from enum import Enum, EnumMeta
from json import JSONDecodeError
from typing import Any, Iterable, List, Literal, Optional
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
from pydantic_core import CoreSchema, core_schema
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
Block,
@@ -23,8 +22,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data import llm_registry
from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -69,123 +66,114 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
"""
Returns a CredentialsField for LLM providers.
The discriminator_mapping will be refreshed when the schema is generated
if it's empty, ensuring the LLM registry is loaded.
"""
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
discriminator_mapping=mapping, # May be empty initially, refreshed later
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)
def llm_model_schema_extra() -> dict[str, Any]:
return {"options": llm_registry.get_llm_model_schema_options()}
class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
class LlmModelMeta(type):
"""
Metaclass for LlmModel that enables attribute-style access to dynamic models.
This allows code like `LlmModel.GPT4O` to work by converting the attribute
name to a slug format:
- GPT4O -> gpt-4o
- GPT4O_MINI -> gpt-4o-mini
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
"""
def __getattr__(cls, name: str):
# Don't intercept private/dunder attributes
if name.startswith("_"):
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
# Convert attribute name to slug format:
# 1. Lowercase: GPT4O -> gpt4o
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
slug = name.lower().replace("_", "-")
# Check for exact match in registry first (e.g., "o1" stays "o1")
registry_slugs = llm_registry.get_dynamic_model_slugs()
if slug in registry_slugs:
return cls(slug)
# If no exact match, try inserting hyphen between letter and digit
# e.g., gpt4o -> gpt-4o
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
return cls(transformed_slug)
def __iter__(cls):
"""Iterate over all models from the registry.
Yields LlmModel instances for each model in the dynamic registry.
Used by __get_pydantic_json_schema__ to build model metadata.
"""
for model in llm_registry.iter_dynamic_models():
yield cls(model.slug)
class LlmModelMeta(EnumMeta):
pass
class LlmModel(str, metaclass=LlmModelMeta):
"""
Dynamic LLM model type that accepts any model slug from the registry.
This is a string subclass (not an Enum) that allows any model slug value.
All models are managed via the LLM Registry in the database.
Usage:
model = LlmModel("gpt-4o") # Direct construction
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
model.value # Returns the slug string
model.provider # Returns the provider from registry
"""
def __new__(cls, value: str):
if isinstance(value, LlmModel):
return value
return str.__new__(cls, value)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""
Tell Pydantic how to validate LlmModel.
Accepts strings and converts them to LlmModel instances.
"""
return core_schema.no_info_after_validator_function(
cls, # The validator function (LlmModel constructor)
core_schema.str_schema(), # Accept string input
serialization=core_schema.to_string_ser_schema(), # Serialize as string
)
@property
def value(self) -> str:
"""Return the model slug (for compatibility with enum-style access)."""
return str(self)
@classmethod
def default(cls) -> "LlmModel":
"""
Get the default model from the registry.
Returns the recommended model if set, otherwise gpt-4o if available
and enabled, otherwise the first enabled model from the registry.
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
"""
from backend.data.llm_registry import get_default_model_slug
slug = get_default_model_slug()
if slug is None:
# Registry is empty (e.g., at module import time before DB connection).
# Fall back to gpt-4o for backward compatibility.
slug = "gpt-4o"
return cls(slug)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
# Groq models
LLAMA3_3_70B = "llama-3.3-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_3 = "llama3.3"
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
@@ -193,15 +181,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
llm_model_metadata = {}
for model in cls:
model_name = model.value
# Skip disabled models - only show enabled models in the picker
if not llm_registry.is_model_enabled(model_name):
continue
# Use registry directly with None check to gracefully handle
# missing metadata during startup/import before registry is populated
metadata = llm_registry.get_llm_model_metadata(model_name)
if metadata is None:
# Skip models without metadata (registry not yet populated)
continue
metadata = model.metadata
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
@@ -217,12 +197,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
return MODEL_METADATA[self]
@property
def provider(self) -> str:
@@ -237,9 +212,300 @@ class LlmModel(str, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens
# Default model constant for backward compatibility
# Uses the dynamic registry to get the default model
DEFAULT_LLM_MODEL = LlmModel.default()
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
LlmModel.O3_MINI: ModelMetadata(
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata(
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata(
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata(
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
),
LlmModel.GPT5_1: ModelMetadata(
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
),
LlmModel.GPT5: ModelMetadata(
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_MINI: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_NANO: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_CHAT: ModelMetadata(
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
),
LlmModel.GPT41: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
),
LlmModel.GPT41_MINI: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata(
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
"aiml_api",
128000,
40000,
"Llama 3.1 Nemotron 70B Instruct",
"AI/ML",
"Nvidia",
1,
),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
),
# https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata(
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
),
LlmModel.LLAMA3_1_8B: ModelMetadata(
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65535,
"Gemini 2.5 Flash Lite Preview 06.17",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
"open_router",
1048576,
8192,
"Gemini 2.0 Flash Lite 001",
"OpenRouter",
"Google",
1,
),
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
16000,
"Sonar Deep Research",
"OpenRouter",
"Perplexity",
3,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router",
131000,
4096,
"Hermes 3 Llama 3.1 405B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router",
12288,
12288,
"Hermes 3 Llama 3.1 70B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Qwen 3 235B A22B Thinking 2507",
"OpenRouter",
"Qwen",
1,
),
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Scout 17B 16E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Maverick 17B 128E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
}
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
@@ -332,11 +598,8 @@ def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
) -> bool | openai.Omit:
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
# parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
is_o_series = re.match(r"^o\d", llm_model) is not None
if is_o_series or parallel_tool_calls is None:
return openai.NOT_GIVEN
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.omit
return parallel_tool_calls
@@ -371,93 +634,15 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
# Get model metadata and check if enabled - with fallback support
# The model we'll actually use (may differ if original is disabled)
model_to_use = llm_model.value
# Check if model is in registry and if it's enabled
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
model_info = get_model_info(llm_model.value)
if model_info and not model_info.is_enabled:
# Model is disabled - try to find a fallback from the same provider
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
)
model_to_use = fallback.slug
# Use fallback model's metadata
provider = fallback.metadata.provider
context_window = fallback.metadata.context_window
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
# No fallback available - raise error
raise ValueError(
f"LLM model '{llm_model.value}' is disabled and no fallback model "
f"from the same provider is available. Please enable the model or "
f"select a different model in the block configuration."
)
else:
# Model is enabled or not in registry (legacy/static model)
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
except ValueError:
# Model not in cache - try refreshing the registry once if we have DB access
logger.warning(f"Model {llm_model.value} not found in registry cache")
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
logger.info(
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
)
except ValueError:
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
)
# Create effective model for model-specific parameter resolution (e.g., o-series check)
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
effective_model = LlmModel(model_to_use)
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
result = await compress_context(
messages=prompt,
target_tokens=context_window // 2,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
)
if result.error:
logger.warning(
@@ -468,7 +653,7 @@ async def llm_call(
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
# model_max_output already set above
model_max_output = llm_model.max_output_tokens or int(2**15)
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -479,14 +664,14 @@ async def llm_call(
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
if force_json_output:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -533,7 +718,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=model_to_use,
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -597,7 +782,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -619,7 +804,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=model_to_use,
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -641,7 +826,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -649,7 +834,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -683,7 +868,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -691,7 +876,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -718,7 +903,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.AsyncOpenAI(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
@@ -728,8 +913,8 @@ async def llm_call(
},
)
completion = await client.chat.completions.create(
model=model_to_use,
completion = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -757,11 +942,11 @@ async def llm_call(
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -812,10 +997,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
@@ -878,7 +1062,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1244,10 +1428,9 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
@@ -1341,9 +1524,8 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for summarizing the text.",
json_schema_extra=llm_model_schema_extra(),
)
focus: str = SchemaField(
title="Focus",
@@ -1559,9 +1741,8 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for the conversation.",
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_tokens: int | None = SchemaField(
@@ -1598,7 +1779,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1661,10 +1842,9 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for generating the list.",
advanced=True,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_retries: int = SchemaField(
@@ -1719,7 +1899,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

View File

@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default_factory=llm.LlmModel.default,
default=llm.DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm.llm_model_schema_extra(),
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(

View File

@@ -10,13 +10,13 @@ import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = self.metadata
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
@property
def provider(self) -> str:
return self.metadata.provider
return MODEL_METADATA[LlmModel(self.value)].provider
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
# Fallback to LlmModel enum if registry lookup fails
return LlmModel(self.value).metadata
return MODEL_METADATA[LlmModel(self.value)]
@property
def context_window(self) -> int:
return self.metadata.context_window
return MODEL_METADATA[LlmModel(self.value)].context_window
@property
def max_output_tokens(self) -> int | None:
return self.metadata.max_output_tokens
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
class StagehandObserveBlock(Block):

View File

@@ -19,30 +19,6 @@ CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a
async def initialize_blocks() -> None:
# Refresh LLM registry before initializing blocks so blocks can use registry data
# This ensures the registry cache is populated even in executor context
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.blocks import get_blocks
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry

View File

@@ -1,8 +1,5 @@
import logging
from typing import Type
import prisma.models
from backend.blocks._base import Block, BlockCost, BlockCostType
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
from backend.blocks.ai_image_generator_block import AIImageGeneratorBlock, ImageGenModel
@@ -27,11 +24,13 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIListGeneratorBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
@@ -39,7 +38,6 @@ from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data import llm_registry
from backend.integrations.credentials_store import (
aiml_api_credentials,
anthropic_credentials,
@@ -59,112 +57,210 @@ from backend.integrations.credentials_store import (
v0_credentials,
)
logger = logging.getLogger(__name__)
# =============== Configure the cost for each LLM Model call =============== #
PROVIDER_CREDENTIALS = {
"openai": openai_credentials,
"anthropic": anthropic_credentials,
"groq": groq_credentials,
"open_router": open_router_credentials,
"llama_api": llama_api_credentials,
"aiml_api": aiml_api_credentials,
"v0": v0_credentials,
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
}
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry
LLM_COST: list[BlockCost] = []
for model in LlmModel:
if model not in MODEL_COST:
raise ValueError(f"Missing MODEL_COST for model: {model}")
async def _build_llm_costs_from_registry() -> list[BlockCost]:
"""
Build BlockCost list from all models in the LLM registry.
This function checks for active model migrations with customCreditCost overrides.
When a model has been migrated with a custom price, that price is used instead
of the target model's default cost.
"""
# Query active migrations with custom pricing overrides
migration_overrides: dict[str, int] = {}
try:
active_migrations = await prisma.models.LlmModelMigration.prisma().find_many(
where={
"isReverted": False,
"customCreditCost": {"not": None},
}
)
migration_overrides = {
migration.sourceModelSlug: migration.customCreditCost
for migration in active_migrations
if migration.customCreditCost is not None
}
if migration_overrides:
logger.info(
"Found %d active model migrations with custom pricing overrides",
len(migration_overrides),
)
except Exception as exc:
logger.warning(
"Failed to query model migration overrides: %s. Proceeding with default costs.",
exc,
exc_info=True,
)
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
if not credentials:
logger.warning(
"Skipping cost entry for %s due to unknown credentials provider %s",
model.slug,
cost.credential_provider,
)
continue
# Check if this model has a custom cost override from migration
cost_amount = migration_overrides.get(model.slug, cost.credit_cost)
if model.slug in migration_overrides:
logger.debug(
"Applying custom cost override for model %s: %d credits (default: %d)",
model.slug,
cost_amount,
cost.credit_cost,
)
cost_filter = {
"model": model.slug,
LLM_COST = (
# Anthropic Models
[
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": credentials.id,
"provider": credentials.provider,
"type": credentials.type,
"id": anthropic_credentials.id,
"provider": anthropic_credentials.provider,
"type": anthropic_credentials.type,
},
}
costs.append(
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost_amount,
)
)
return costs
async def refresh_llm_costs() -> None:
"""
Refresh LLM costs from the registry. All costs now come from the database.
This function also checks for active model migrations with custom pricing overrides
and applies them to ensure accurate billing.
"""
LLM_COST.clear()
LLM_COST.extend(await _build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "anthropic"
]
# OpenAI Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
"type": openai_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "openai"
]
# Groq Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {"id": groq_credentials.id},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "groq"
]
# Open Router Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": open_router_credentials.id,
"provider": open_router_credentials.provider,
"type": open_router_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "open_router"
]
# Llama API Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": llama_api_credentials.id,
"provider": llama_api_credentials.provider,
"type": llama_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": aiml_api_credentials.id,
"provider": aiml_api_credentials.provider,
"type": aiml_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "aiml_api"
]
)
# =============== This is the exhaustive list of costs for each Block =============== #

View File
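In the cost-config hunk above, _build_llm_costs_from_registry resolves each model's credit cost as "migration override if one is active, otherwise the cost from the registry entry". A small standalone sketch of that precedence rule (slugs and amounts are made up):

# Mirrors the customCreditCost lookup: source model slug -> override in credits.
migration_overrides = {"claude-3-haiku": 2}

def resolve_credit_cost(slug: str, registry_cost: int) -> int:
    # An active migration override wins; otherwise keep the registry's own cost.
    return migration_overrides.get(slug, registry_cost)

assert resolve_credit_cost("claude-3-haiku", registry_cost=1) == 2   # overridden
assert resolve_credit_cost("gpt-4o-mini", registry_cost=1) == 1      # default kept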

@@ -1625,10 +1625,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Get all model slugs from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block

View File

@@ -1,72 +0,0 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.
This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""
from backend.data.llm_registry.model import ModelMetadata
# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
REGISTRY_REFRESH_CHANNEL,
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_default_model_slug,
get_dynamic_model_slugs,
get_fallback_model_for_disabled,
get_llm_discriminator_mapping,
get_llm_model_cost,
get_llm_model_metadata,
get_llm_model_schema_options,
get_model_info,
is_model_enabled,
iter_dynamic_models,
refresh_llm_registry,
register_static_costs,
register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
is_llm_model_field,
refresh_llm_discriminator_mapping,
refresh_llm_model_options,
update_schema_with_llm_registry,
)
__all__ = [
# Types
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Registry functions
"get_all_model_slugs_for_validation",
"get_default_model_slug",
"get_dynamic_model_slugs",
"get_fallback_model_for_disabled",
"get_llm_discriminator_mapping",
"get_llm_model_cost",
"get_llm_model_metadata",
"get_llm_model_schema_options",
"get_model_info",
"is_model_enabled",
"iter_dynamic_models",
"refresh_llm_registry",
"register_static_costs",
"register_static_metadata",
# Notifications
"REGISTRY_REFRESH_CHANNEL",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Schema utilities
"is_llm_model_field",
"refresh_llm_discriminator_mapping",
"refresh_llm_model_options",
"update_schema_with_llm_registry",
]

View File

@@ -1,25 +0,0 @@
"""Type definitions for LLM model metadata."""
from typing import Literal, NamedTuple
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model.
Attributes:
provider: The provider identifier (e.g., "openai", "anthropic")
context_window: Maximum context window size in tokens
max_output_tokens: Maximum output tokens (None if unlimited)
display_name: Human-readable name for the model
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
creator_name: Name of the organization that created the model
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
"""
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]

View File
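To make the field order of the ModelMetadata NamedTuple above concrete, here is a hedged construction example (the numbers are illustrative, not real model limits; the import path matches the module shown elsewhere in this diff):

from backend.data.llm_registry.model import ModelMetadata

example = ModelMetadata(
    provider="openai",
    context_window=128_000,
    max_output_tokens=16_384,
    display_name="GPT-4o",
    provider_name="OpenAI",
    creator_name="OpenAI",
    price_tier=2,          # 1=cheapest, 2=medium, 3=expensive
)
assert example.max_output_tokens is not None
assert example.price_tier in (1, 2, 3)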

@@ -1,89 +0,0 @@
"""
Redis pub/sub notifications for LLM registry updates.
When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
from backend.data.redis_client import connect_async
logger = logging.getLogger(__name__)
# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",
exc,
exc_info=True,
)
except Exception as exc:
logger.warning(
"Error processing registry refresh message: %s", exc, exc_info=True
)
# Continue listening even if one message fails
await asyncio.sleep(1)
except Exception as exc:
logger.error(
"Failed to subscribe to LLM registry refresh notifications: %s",
exc,
exc_info=True,
)
raise

View File
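A hedged usage sketch for the notification helpers above: a long-lived task subscribes with an async callback while the admin-side writer publishes after editing models. Only publish_registry_refresh_notification and subscribe_to_registry_refresh are taken from the code above; the callback body and orchestration are illustrative.

import asyncio

from backend.data.llm_registry.notifications import (
    publish_registry_refresh_notification,
    subscribe_to_registry_refresh,
)

async def on_refresh() -> None:
    # Illustrative callback: a real service would re-query the DB here.
    print("registry changed, refreshing local cache")

async def main() -> None:
    # subscribe_to_registry_refresh() loops forever, so run it as a background task.
    listener = asyncio.create_task(subscribe_to_registry_refresh(on_refresh))
    await asyncio.sleep(0)        # let the subscription start
    await publish_registry_refresh_notification()
    await asyncio.sleep(2)        # give the listener a poll cycle to react
    listener.cancel()

asyncio.run(main())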

@@ -1,388 +0,0 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
"""Register static metadata for legacy models (deprecated)."""
_static_metadata.update({str(key): value for key, value in metadata.items()})
_refresh_cached_schema()
def register_static_costs(costs: dict[Any, int]) -> None:
"""Register static costs for legacy models (deprecated)."""
_static_costs.update({str(key): value for key, value in costs.items()})
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
options: list[dict[str, str]] = []
# Only include enabled models in the dropdown options
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
options.append(
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
)
for slug, metadata in _static_metadata.items():
if slug in _dynamic_models:
continue
options.append(
{
"label": slug,
"value": slug,
"group": metadata.provider,
"description": "",
}
)
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display_name = (
record.Provider.displayName if record.Provider else record.providerId
)
# Creator name: prefer Creator.name, fallback to provider display name
creator_name = (
record.Creator.name if record.Creator else provider_display_name
)
# Price tier: default to 1 (cheapest) if not set
price_tier = getattr(record, "priceTier", 1) or 1
# Clamp to valid range 1-3
price_tier = max(1, min(3, price_tier))
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
display_name=record.displayName,
provider_name=provider_display_name,
creator_name=creator_name,
price_tier=price_tier, # type: ignore[arg-type]
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
# Map creator if present
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
# Atomic swap - build new structures then replace references
# This ensures readers never see partially updated state
global _dynamic_models
_dynamic_models = dynamic
_refresh_cached_schema()
logger.info(
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
len(dynamic),
sum(1 for m in dynamic.values() if m.is_enabled),
sum(1 for m in dynamic.values() if not m.is_enabled),
)
def _refresh_cached_schema() -> None:
"""Refresh cached schema options and discriminator mapping."""
global _schema_options, _discriminator_mapping
# Build new structures
new_options = _build_schema_options()
new_mapping = {
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
}
for slug, metadata in _static_metadata.items():
new_mapping.setdefault(slug, metadata.provider)
# Atomic swap - replace references to ensure readers see consistent state
_schema_options = new_options
_discriminator_mapping = new_mapping
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
if slug in _dynamic_models:
return _dynamic_models[slug].metadata
return _static_metadata.get(slug)
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
"""Get model cost configuration by slug."""
if slug in _dynamic_models:
return _dynamic_models[slug].costs
cost_value = _static_costs.get(slug)
if cost_value is None:
return tuple()
return (
RegistryModelCost(
credit_cost=cost_value,
credential_provider="static",
credential_id=None,
credential_type=None,
currency=None,
metadata={},
),
)
def get_llm_model_schema_options() -> list[dict[str, str]]:
"""
Get schema options for LLM model selection dropdown.
Returns a copy of cached schema options that are refreshed when the registry is
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return list(_schema_options)
def get_llm_discriminator_mapping() -> dict[str, str]:
"""
Get discriminator mapping for LLM models.
Returns a copy of cached discriminator mapping that is refreshed when the registry
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return dict(_discriminator_mapping)
def get_dynamic_model_slugs() -> set[str]:
"""Get all dynamic model slugs from the registry."""
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
"""Iterate over all dynamic models in the registry."""
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider, preferring the one whose
context window is closest to the disabled model's (ties broken by display name).
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
def get_default_model_slug() -> str | None:
"""
Get the default model slug to use for block defaults.
Returns the recommended model if set (configured via admin UI),
otherwise returns the first enabled model alphabetically.
Returns None if no models are available or enabled.
"""
# Return the recommended model if one is set and enabled
for model in _dynamic_models.values():
if model.is_recommended and model.is_enabled:
return model.slug
# No recommended model set - find first enabled model alphabetically
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
logger.warning(
"No recommended model set, using '%s' as default",
model.slug,
)
return model.slug
# No enabled models available
if _dynamic_models:
logger.error(
"No enabled models found in registry (%d models registered but all disabled)",
len(_dynamic_models),
)
else:
logger.error("No models registered in LLM registry")
return None

View File
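The ranking used by get_fallback_model_for_disabled() above picks the same-provider model whose context window is closest to the disabled one, breaking ties by display name. A standalone illustration of that sort key with toy data:

# Toy same-provider candidates: (display_name, context_window)
candidates = [("Model B", 32_000), ("Model A", 200_000), ("Model C", 128_000)]
disabled_context_window = 120_000

best = min(
    candidates,
    # Same key as above: distance to the disabled model's window, then name.
    key=lambda m: (abs(m[1] - disabled_context_window), m[0].lower()),
)
assert best == ("Model C", 128_000)   # closest window wins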

@@ -1,130 +0,0 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data.llm_registry.registry import (
get_all_model_slugs_for_validation,
get_default_model_slug,
get_llm_discriminator_mapping,
get_llm_model_schema_options,
)
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options from the registry.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may have disabled models selected - they should pass validation
and the fallback logic in llm_call() will handle using an alternative model.
"""
fresh_options = get_llm_model_schema_options()
if not fresh_options:
return
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
all_known_slugs = get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
# Set the default value from the registry (recommended model if set, else first enabled)
# This ensures new blocks have a sensible default pre-selected
default_slug = get_default_model_slug()
if default_slug:
field_schema["default"] = default_slug
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = get_llm_discriminator_mapping()
if fresh_mapping is not None:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name,
exc,
)

View File
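The central idea of refresh_llm_model_options() above is that the UI "options" list shrinks to enabled models while the validation "enum" stays a superset of every known slug, so existing graphs keep validating. A standalone sketch of that split (the schema shape mirrors the code above; slugs are made up):

field_schema = {
    "options": [{"label": "Old", "value": "old-model"}],
    "enum": ["old-model"],
}
enabled_options = [{"label": "New", "value": "new-model"}]       # dropdown: enabled only
all_known_slugs = {"old-model", "new-model", "disabled-model"}   # validation: everything

field_schema["options"] = enabled_options
field_schema["enum"] = sorted(set(field_schema["enum"]) | all_known_slugs)

assert "disabled-model" in field_schema["enum"]                        # still validates
assert [o["value"] for o in field_schema["options"]] == ["new-model"]  # not offered in UI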

@@ -39,7 +39,6 @@ from pydantic_core import (
)
from typing_extensions import TypedDict
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.request import parse_url
@@ -551,9 +550,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
# Do not return anything, just mutate schema in place
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
@@ -708,20 +705,16 @@ def CredentialsField(
This is enforced by the `BlockSchema` base class.
"""
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)
if discriminator_values:
field_schema_extra["discriminator_values"] = discriminator_values
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:

View File

@@ -1,67 +0,0 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.blocks._base import BlockSchema
from backend.data import db, llm_registry
from backend.data.block import initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Initialize blocks (internally refreshes LLM registry and costs)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -708,20 +708,6 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
subscribe_to_registry_updates,
)
asyncio.run_coroutine_threadsafe(
initialize_registry_for_executor(), self.node_execution_loop
)
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File
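The ExecutionProcessor hunk above hands the async registry setup to an event loop owned by a worker thread via asyncio.run_coroutine_threadsafe. A minimal standalone sketch of that pattern, with a stand-in coroutine instead of initialize_registry_for_executor:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def initialize() -> str:
    # Stand-in for the real async initialization work.
    return "initialized"

# Submit the coroutine to the other thread's loop; a concurrent.futures.Future comes back.
future = asyncio.run_coroutine_threadsafe(initialize(), loop)
print(future.result(timeout=5))          # -> "initialized"
loop.call_soon_threadsafe(loop.stop)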

@@ -1,935 +0,0 @@
from __future__ import annotations
from typing import Any, Iterable, Sequence, cast
import prisma
import prisma.models
from backend.data.db import transaction
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
def _json_dict(value: Any | None) -> dict[str, Any]:
if not value:
return {}
if isinstance(value, dict):
return value
return {}
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
return llm_model.LlmModelCost(
id=record.id,
unit=record.unit,
credit_cost=record.creditCost,
credential_provider=record.credentialProvider,
credential_id=record.credentialId,
credential_type=record.credentialType,
currency=record.currency,
metadata=_json_dict(record.metadata),
)
def _map_creator(
record: prisma.models.LlmModelCreator,
) -> llm_model.LlmModelCreator:
return llm_model.LlmModelCreator(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
website_url=record.websiteUrl,
logo_url=record.logoUrl,
metadata=_json_dict(record.metadata),
)
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
costs = []
if record.Costs:
costs = [_map_cost(cost) for cost in record.Costs]
creator = None
if hasattr(record, "Creator") and record.Creator:
creator = _map_creator(record.Creator)
return llm_model.LlmModel(
id=record.id,
slug=record.slug,
display_name=record.displayName,
description=record.description,
provider_id=record.providerId,
creator_id=record.creatorId,
creator=creator,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
capabilities=_json_dict(record.capabilities),
metadata=_json_dict(record.metadata),
costs=costs,
)
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
models: list[llm_model.LlmModel] = []
if record.Models:
models = [_map_model(model) for model in record.Models]
return llm_model.LlmProvider(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
default_credential_provider=record.defaultCredentialProvider,
default_credential_id=record.defaultCredentialId,
default_credential_type=record.defaultCredentialType,
supports_tools=record.supportsTools,
supports_json_output=record.supportsJsonOutput,
supports_reasoning=record.supportsReasoning,
supports_parallel_tool=record.supportsParallelTool,
metadata=_json_dict(record.metadata),
models=models,
)
async def list_providers(
include_models: bool = True, enabled_only: bool = False
) -> list[llm_model.LlmProvider]:
"""
List all LLM providers.
Args:
include_models: Whether to include models for each provider
enabled_only: If True, only include enabled models (for public routes)
"""
include: Any = None
if include_models:
model_where = {"isEnabled": True} if enabled_only else None
include = {
"Models": {
"include": {"Costs": True, "Creator": True},
"where": model_where,
}
}
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
async def upsert_provider(
request: llm_model.UpsertLlmProviderRequest,
provider_id: str | None = None,
) -> llm_model.LlmProvider:
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"defaultCredentialProvider": request.default_credential_provider,
"defaultCredentialId": request.default_credential_id,
"defaultCredentialType": request.default_credential_type,
"supportsTools": request.supports_tools,
"supportsJsonOutput": request.supports_json_output,
"supportsReasoning": request.supports_reasoning,
"supportsParallelTool": request.supports_parallel_tool,
"metadata": prisma.Json(request.metadata or {}),
}
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
if provider_id:
record = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
data=data,
include=include,
)
else:
record = await prisma.models.LlmProvider.prisma().create(
data=data,
include=include,
)
if record is None:
raise ValueError("Failed to create/update provider")
return _map_provider(record)
async def delete_provider(provider_id: str) -> bool:
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Due to onDelete: Restrict on LlmModel.Provider, the database will
block deletion if models exist.
Args:
provider_id: UUID of the provider to delete
Returns:
True if deleted successfully
Raises:
ValueError: If provider not found or has associated models
"""
# Check if provider exists
provider = await prisma.models.LlmProvider.prisma().find_unique(
where={"id": provider_id},
include={"Models": True},
)
if not provider:
raise ValueError(f"Provider with id '{provider_id}' not found")
# Check if provider has any models
model_count = len(provider.Models) if provider.Models else 0
if model_count > 0:
raise ValueError(
f"Cannot delete provider '{provider.displayName}' because it has "
f"{model_count} model(s). Delete all models first."
)
# Safe to delete
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
return True
async def list_models(
provider_id: str | None = None,
enabled_only: bool = False,
page: int = 1,
page_size: int = 50,
) -> llm_model.LlmModelsResponse:
"""
List LLM models with pagination.
Args:
provider_id: Optional filter by provider ID
enabled_only: If True, only return enabled models (for public routes)
page: Page number (1-indexed)
page_size: Number of models per page
"""
where: Any = {}
if provider_id:
where["providerId"] = provider_id
if enabled_only:
where["isEnabled"] = True
# Get total count for pagination
total_items = await prisma.models.LlmModel.prisma().count(
where=where if where else None
)
# Calculate pagination
skip = (page - 1) * page_size
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
records = await prisma.models.LlmModel.prisma().find_many(
where=where if where else None,
include={"Costs": True, "Creator": True},
skip=skip,
take=page_size,
)
models = [_map_model(record) for record in records]
return llm_model.LlmModelsResponse(
models=models,
pagination=Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
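The pagination in list_models() above is a plain ceiling division over the total row count. A quick standalone check with example numbers:

page, page_size, total_items = 3, 50, 101

skip = (page - 1) * page_size
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0

assert (skip, total_pages) == (100, 3)   # page 3 holds the single remaining record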
def _cost_create_payload(
costs: Sequence[llm_model.LlmModelCostInput],
) -> dict[str, Iterable[dict[str, Any]]]:
create_items = []
for cost in costs:
item: dict[str, Any] = {
"unit": cost.unit,
"creditCost": cost.credit_cost,
"credentialProvider": cost.credential_provider,
}
# Only include optional fields if they have values
if cost.credential_id:
item["credentialId"] = cost.credential_id
if cost.credential_type:
item["credentialType"] = cost.credential_type
if cost.currency:
item["currency"] = cost.currency
# Handle metadata - use Prisma Json type
if cost.metadata is not None and cost.metadata != {}:
item["metadata"] = prisma.Json(cost.metadata)
create_items.append(item)
return {"create": create_items}
async def create_model(
request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
data: Any = {
"slug": request.slug,
"displayName": request.display_name,
"description": request.description,
"Provider": {"connect": {"id": request.provider_id}},
"contextWindow": request.context_window,
"maxOutputTokens": request.max_output_tokens,
"isEnabled": request.is_enabled,
"capabilities": prisma.Json(request.capabilities or {}),
"metadata": prisma.Json(request.metadata or {}),
"Costs": _cost_create_payload(request.costs),
}
if request.creator_id:
data["Creator"] = {"connect": {"id": request.creator_id}}
record = await prisma.models.LlmModel.prisma().create(
data=data,
include={"Costs": True, "Creator": True, "Provider": True},
)
return _map_model(record)
async def update_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
# Build scalar field updates (non-relation fields)
scalar_data: Any = {}
if request.display_name is not None:
scalar_data["displayName"] = request.display_name
if request.description is not None:
scalar_data["description"] = request.description
if request.context_window is not None:
scalar_data["contextWindow"] = request.context_window
if request.max_output_tokens is not None:
scalar_data["maxOutputTokens"] = request.max_output_tokens
if request.is_enabled is not None:
scalar_data["isEnabled"] = request.is_enabled
if request.capabilities is not None:
scalar_data["capabilities"] = request.capabilities
if request.metadata is not None:
scalar_data["metadata"] = request.metadata
# Foreign keys can be updated directly as scalar fields
if request.provider_id is not None:
scalar_data["providerId"] = request.provider_id
if request.creator_id is not None:
# Empty string means remove the creator
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
# If we have costs to update, we need to handle them separately
# because nested writes have different constraints
if request.costs is not None:
# Wrap cost replacement in a transaction for atomicity
async with transaction() as tx:
# First update scalar fields
if scalar_data:
await tx.llmmodel.update(
where={"id": model_id},
data=scalar_data,
)
# Then handle costs: delete existing and create new
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
if request.costs:
cost_payload = _cost_create_payload(request.costs)
for cost_item in cost_payload["create"]:
cost_item["llmModelId"] = model_id
await tx.llmmodelcost.create(data=cast(Any, cost_item))
# Fetch the updated record (outside transaction)
record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
else:
# No costs update - simple update
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data=scalar_data,
include={"Costs": True, "Creator": True},
)
if not record:
raise ValueError(f"Model with id '{model_id}' not found")
return _map_model(record)
async def toggle_model(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> llm_model.ToggleLlmModelResponse:
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
Args:
model_id: UUID of the model to toggle
is_enabled: New enabled status
migrate_to_slug: If disabling and this is provided, migrate all workflows
using this model to the specified replacement model
migration_reason: Optional reason for the migration (e.g., "Provider outage")
custom_credit_cost: Optional custom pricing override for migrated workflows.
When set, the billing system should use this cost instead
of the target model's cost for affected nodes.
Returns:
ToggleLlmModelResponse with the updated model and optional migration stats
"""
import json
# Get the model being toggled
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
nodes_migrated = 0
migration_id: str | None = None
# If disabling with migration, perform migration first
if not is_enabled and migrate_to_slug:
# Validate replacement model exists and is enabled
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# Perform all operations atomically within a single transaction
# This ensures no nodes are missed between query and update
async with transaction() as tx:
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
# Update by IDs to ensure we only update the exact nodes we queried
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
node_ids_json,
)
record = await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
# Create migration record for revert capability
if nodes_migrated > 0:
migration_data: Any = {
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
migration_record = await tx.llmmodelmigration.create(
data=migration_data
)
migration_id = migration_record.id
else:
# Simple toggle without migration
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
if record is None:
raise ValueError(f"Model with id '{model_id}' not found")
return llm_model.ToggleLlmModelResponse(
model=_map_model(record),
nodes_migrated=nodes_migrated,
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
migration_id=migration_id,
)
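A hedged usage sketch for toggle_model() above, disabling a model and migrating its workflow nodes in one call. The UUID and slugs are placeholders, and toggle_model is assumed to be in scope from the module shown above; only the signature and response fields visible in the code are relied on.

async def disable_with_migration() -> None:
    response = await toggle_model(
        model_id="00000000-0000-0000-0000-000000000000",   # placeholder UUID
        is_enabled=False,
        migrate_to_slug="gpt-4o-mini",          # must exist and be enabled
        migration_reason="Provider outage",
        custom_credit_cost=2,                   # optional billing override
    )
    print(response.nodes_migrated, response.migration_id)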
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
"""Get usage count for a model."""
import prisma as prisma_module
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
model.slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
async def delete_model(
model_id: str, replacement_model_slug: str | None = None
) -> llm_model.DeleteLlmModelResponse:
"""
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
This performs an atomic operation within a database transaction:
1. Validates the model exists
2. Counts affected nodes
3. If nodes exist, validates replacement model and migrates them
4. Deletes the LlmModel record (CASCADE deletes costs)
Args:
model_id: UUID of the model to delete
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
Returns:
DeleteLlmModelResponse with migration stats
Raises:
ValueError: If model not found, nodes exist but no replacement provided,
replacement not found, or replacement is disabled
"""
# 1. Get the model being deleted (validation - outside transaction)
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
# 2. Count affected nodes first to determine if replacement is needed
import prisma as prisma_module
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
# 3. Validate replacement model only if there are nodes to migrate
if nodes_to_migrate > 0:
if not replacement_model_slug:
raise ValueError(
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
f"are using it. Please provide a replacement_model_slug to migrate them."
)
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{replacement_model_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# 4. Perform migration (if needed) and deletion atomically within a transaction
async with transaction() as tx:
# Migrate all AgentNode.constantInput->model to replacement
if nodes_to_migrate > 0 and replacement_model_slug:
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
# Delete the model (CASCADE will delete costs automatically)
await tx.llmmodel.delete(where={"id": model_id})
# Build appropriate message based on whether migration happened
if nodes_to_migrate > 0:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
)
else:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
f"No workflows were using this model."
)
return llm_model.DeleteLlmModelResponse(
deleted_model_slug=deleted_slug,
deleted_model_display_name=deleted_display_name,
replacement_model_slug=replacement_model_slug,
nodes_migrated=nodes_to_migrate,
message=message,
)
def _map_migration(
record: prisma.models.LlmModelMigration,
) -> llm_model.LlmModelMigration:
return llm_model.LlmModelMigration(
id=record.id,
source_model_slug=record.sourceModelSlug,
target_model_slug=record.targetModelSlug,
reason=record.reason,
node_count=record.nodeCount,
custom_credit_cost=record.customCreditCost,
is_reverted=record.isReverted,
created_at=record.createdAt.isoformat(),
reverted_at=record.revertedAt.isoformat() if record.revertedAt else None,
)
async def list_migrations(
include_reverted: bool = False,
) -> list[llm_model.LlmModelMigration]:
"""
List model migrations, optionally including reverted ones.
Args:
include_reverted: If True, include reverted migrations. Default is False.
Returns:
List of LlmModelMigration records
"""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [_map_migration(record) for record in records]
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
"""Get a specific migration by ID."""
record = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
return _map_migration(record) if record else None
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> llm_model.RevertMigrationResponse:
"""
Revert a model migration, restoring affected nodes to their original model.
This only reverts the specific nodes that were migrated, not all nodes
currently using the target model.
Args:
migration_id: UUID of the migration to revert
re_enable_source_model: Whether to re-enable the source model if it's disabled
Returns:
RevertMigrationResponse with revert stats
Raises:
ValueError: If migration not found, already reverted, or source model not available
"""
import json
from datetime import datetime, timezone
# Get the migration record
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted "
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
)
# Check if source model exists
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists. "
f"Cannot revert migration."
)
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
# Track if we need to re-enable the source model
source_model_was_disabled = not source_model.isEnabled
should_re_enable = source_model_was_disabled and re_enable_source_model
source_model_re_enabled = False
# Perform revert atomically
async with transaction() as tx:
# Re-enable the source model if requested and it was disabled
if should_re_enable:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
# Update only the specific nodes that were migrated
# We need to check that they still have the target model (haven't been changed since)
# Use a single batch update for efficiency
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_json,
migration.targetModelSlug,
)
nodes_reverted = result if result else 0
# Mark migration as reverted
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
# Calculate nodes that were already changed since migration
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
# Build appropriate message
message_parts = [
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
]
if nodes_already_changed > 0:
message_parts.append(
f" {nodes_already_changed} node(s) were already changed and not reverted."
)
if source_model_re_enabled:
message_parts.append(
f" Model '{migration.sourceModelSlug}' has been re-enabled."
)
return llm_model.RevertMigrationResponse(
migration_id=migration_id,
source_model_slug=migration.sourceModelSlug,
target_model_slug=migration.targetModelSlug,
nodes_reverted=nodes_reverted,
nodes_already_changed=nodes_already_changed,
source_model_re_enabled=source_model_re_enabled,
message="".join(message_parts),
)
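One bookkeeping detail in revert_migration() above: execute_raw reports how many rows it actually updated, so any node that was manually repointed after the migration is left alone and counted as "already changed". A tiny standalone illustration:

migrated_node_ids = ["n1", "n2", "n3", "n4"]
nodes_reverted = 3                     # rows still on the target model were updated
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
assert nodes_already_changed == 1      # one node was edited since the migration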
# ============================================================================
# Creator CRUD operations
# ============================================================================
async def list_creators() -> list[llm_model.LlmModelCreator]:
"""List all LLM model creators."""
records = await prisma.models.LlmModelCreator.prisma().find_many(
order={"displayName": "asc"}
)
return [_map_creator(record) for record in records]
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
"""Get a specific creator by ID."""
record = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
return _map_creator(record) if record else None
async def upsert_creator(
request: llm_model.UpsertLlmCreatorRequest,
creator_id: str | None = None,
) -> llm_model.LlmModelCreator:
"""Create or update a model creator."""
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"websiteUrl": request.website_url,
"logoUrl": request.logo_url,
"metadata": prisma.Json(request.metadata or {}),
}
if creator_id:
record = await prisma.models.LlmModelCreator.prisma().update(
where={"id": creator_id},
data=data,
)
else:
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
if record is None:
raise ValueError("Failed to create/update creator")
return _map_creator(record)
async def delete_creator(creator_id: str) -> bool:
"""
Delete a model creator.
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
Args:
creator_id: UUID of the creator to delete
Returns:
True if deleted successfully
Raises:
ValueError: If creator not found
"""
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
if not creator:
raise ValueError(f"Creator with id '{creator_id}' not found")
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
return True
async def get_recommended_model() -> llm_model.LlmModel | None:
"""
Get the currently recommended LLM model.
Returns:
The recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
include={"Costs": True, "Creator": True},
)
return _map_model(record) if record else None
async def set_recommended_model(
model_id: str,
) -> tuple[llm_model.LlmModel, str | None]:
"""
Set a model as the recommended model.
This will clear the isRecommended flag from any other model and set it
on the specified model. The model must be enabled.
Args:
model_id: UUID of the model to set as recommended
Returns:
Tuple of (the updated model, previous recommended model slug or None)
Raises:
ValueError: If model not found or not enabled
"""
# First, verify the model exists and is enabled
target_model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}
)
if not target_model:
raise ValueError(f"Model with id '{model_id}' not found")
if not target_model.isEnabled:
raise ValueError(
f"Cannot set disabled model '{target_model.slug}' as recommended"
)
# Get the current recommended model (if any)
current_recommended = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True}
)
previous_slug = current_recommended.slug if current_recommended else None
# Use a transaction to ensure atomicity
async with transaction() as tx:
# Clear isRecommended from all models
await tx.llmmodel.update_many(
where={"isRecommended": True},
data={"isRecommended": False},
)
# Set the new recommended model
await tx.llmmodel.update(
where={"id": model_id},
data={"isRecommended": True},
)
# Fetch and return the updated model
updated_record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
if not updated_record:
raise ValueError("Failed to fetch updated model")
return _map_model(updated_record), previous_slug
async def get_recommended_model_slug() -> str | None:
"""
Get the slug of the currently recommended LLM model.
Returns:
The slug of the recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
)
return record.slug if record else None
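
Illustrative sketch (not part of this diff): how an admin endpoint could wire set_recommended_model() into FastAPI. The route path and router name are assumptions; the db helper and the request/response models match the definitions in this changeset.

import fastapi
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model

admin_router = fastapi.APIRouter(prefix="/admin/llm", tags=["llm-admin"])  # hypothetical router

@admin_router.post(
    "/models/recommended", response_model=llm_model.SetRecommendedModelResponse
)
async def set_recommended(request: llm_model.SetRecommendedModelRequest):
    try:
        model, previous_slug = await llm_db.set_recommended_model(request.model_id)
    except ValueError as e:  # raised when the model is missing or disabled
        raise fastapi.HTTPException(status_code=400, detail=str(e))
    return llm_model.SetRecommendedModelResponse(
        model=model,
        previous_recommended_slug=previous_slug,
        message=f"Recommended model set to '{model.slug}'",
    )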


@@ -1,235 +0,0 @@
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Optional
import prisma.enums
import pydantic
from backend.util.models import Pagination
# Pattern for valid model slugs: an alphanumeric first character, followed by alphanumerics, dots, underscores, slashes, or hyphens
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
class LlmModelCost(pydantic.BaseModel):
id: str
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = None
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
id: str
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModel(pydantic.BaseModel):
id: str
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
creator: Optional[LlmModelCreator] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
id: str
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = None
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmProvidersResponse(pydantic.BaseModel):
providers: list[LlmProvider]
class LlmModelsResponse(pydantic.BaseModel):
models: list[LlmModel]
pagination: Optional[Pagination] = None
class LlmCreatorsResponse(pydantic.BaseModel):
creators: list[LlmModelCreator]
class UpsertLlmProviderRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = "api_key"
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class UpsertLlmCreatorRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCostInput(pydantic.BaseModel):
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = "api_key"
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class CreateLlmModelRequest(pydantic.BaseModel):
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCostInput]
@pydantic.field_validator("slug")
@classmethod
def validate_slug(cls, v: str) -> str:
if not v or len(v) > 100:
raise ValueError("Slug must be 1-100 characters")
if not SLUG_PATTERN.match(v):
raise ValueError(
"Slug must start with alphanumeric and contain only "
"alphanumeric characters, dots, underscores, slashes, or hyphens"
)
return v
class UpdateLlmModelRequest(pydantic.BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
context_window: Optional[int] = None
max_output_tokens: Optional[int] = None
is_enabled: Optional[bool] = None
capabilities: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
provider_id: Optional[str] = None
creator_id: Optional[str] = None
costs: Optional[list[LlmModelCostInput]] = None
class ToggleLlmModelRequest(pydantic.BaseModel):
is_enabled: bool
migrate_to_slug: Optional[str] = None
migration_reason: Optional[str] = None # e.g., "Provider outage"
# Custom pricing override for migrated workflows. When set, billing should use
# this cost instead of the target model's cost for affected nodes.
# See LlmModelMigration in schema.prisma for full documentation.
custom_credit_cost: Optional[int] = None
class ToggleLlmModelResponse(pydantic.BaseModel):
model: LlmModel
nodes_migrated: int = 0
migrated_to_slug: Optional[str] = None
migration_id: Optional[str] = None # ID of the migration record for revert
class DeleteLlmModelResponse(pydantic.BaseModel):
deleted_model_slug: str
deleted_model_display_name: str
replacement_model_slug: Optional[str] = None
nodes_migrated: int
message: str
class LlmModelUsageResponse(pydantic.BaseModel):
model_slug: str
node_count: int
# Migration tracking models
class LlmModelMigration(pydantic.BaseModel):
id: str
source_model_slug: str
target_model_slug: str
reason: Optional[str] = None
node_count: int
# Custom pricing override - billing should use this instead of target model's cost
custom_credit_cost: Optional[int] = None
is_reverted: bool = False
created_at: datetime
reverted_at: Optional[datetime] = None
class LlmMigrationsResponse(pydantic.BaseModel):
migrations: list[LlmModelMigration]
class RevertMigrationRequest(pydantic.BaseModel):
re_enable_source_model: bool = (
True # Whether to re-enable the source model if disabled
)
class RevertMigrationResponse(pydantic.BaseModel):
migration_id: str
source_model_slug: str
target_model_slug: str
nodes_reverted: int
nodes_already_changed: int = (
0 # Nodes that were modified since migration (not reverted)
)
source_model_re_enabled: bool = False # Whether the source model was re-enabled
message: str
class SetRecommendedModelRequest(pydantic.BaseModel):
model_id: str
class SetRecommendedModelResponse(pydantic.BaseModel):
model: LlmModel
previous_recommended_slug: Optional[str] = None
message: str
class RecommendedModelResponse(pydantic.BaseModel):
model: Optional[LlmModel] = None
slug: Optional[str] = None
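
Illustrative sketch (not part of this diff): the slug validator above in action. Field values are hypothetical; only the required fields of CreateLlmModelRequest and LlmModelCostInput are exercised.

import pydantic
from backend.server.v2.llm import model as llm_model

# Accepted: starts with an alphanumeric character, may contain dots, underscores, slashes, hyphens.
ok = llm_model.CreateLlmModelRequest(
    slug="openai/gpt-4o-mini",
    display_name="GPT 4o Mini",
    provider_id="11111111-1111-1111-1111-111111111111",  # hypothetical provider UUID
    context_window=128000,
    costs=[llm_model.LlmModelCostInput(credit_cost=1, credential_provider="openai")],
)

# Rejected: a leading non-alphanumeric character fails SLUG_PATTERN.
try:
    llm_model.CreateLlmModelRequest(
        slug="/bad-slug",
        display_name="Bad",
        provider_id="11111111-1111-1111-1111-111111111111",
        context_window=1000,
        costs=[],
    )
except pydantic.ValidationError as exc:
    print(exc)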


@@ -1,29 +0,0 @@
import autogpt_libs.auth
import fastapi
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
"""List all enabled LLM models available to users."""
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""List all LLM providers with their enabled models."""
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
return llm_model.LlmProvidersResponse(providers=providers)
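
Illustrative sketch (not part of this diff): calling the two user-facing endpoints above with httpx. The base URL and auth header are assumptions; the paths and query parameters come from the router definitions.

import httpx

async def fetch_registry(base_url: str, token: str):
    headers = {"Authorization": f"Bearer {token}"}  # auth scheme is an assumption
    async with httpx.AsyncClient(base_url=base_url, headers=headers) as client:
        models = (await client.get("/llm/models", params={"page": 1, "page_size": 50})).json()
        providers = (await client.get("/llm/providers")).json()
    return models, providers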


@@ -38,6 +38,7 @@ class Flag(str, Enum):
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
COPILOT_SDK = "copilot-sdk"
def is_configured() -> bool:


@@ -1,81 +0,0 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
-- CreateIndex
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;


@@ -1,226 +0,0 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
-- Insert Providers
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
VALUES
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Models (using CTEs to reference provider IDs)
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- OpenAI models
('o3', 'O3', 'openai', 200000, 100000),
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
('o1', 'O1', 'openai', 200000, 100000),
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
-- Anthropic models
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
-- AI/ML API models
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
-- Groq models
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
-- Ollama models
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
('llama3', 'Llama 3', 'ollama', 8192, NULL),
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
-- OpenRouter models
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
-- Llama API models
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
-- v0 models
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- OpenAI costs
('o3', 4),
('o3-mini', 2),
('o1', 16),
('o1-mini', 4),
('gpt-5-2025-08-07', 2),
('gpt-5.1-2025-11-13', 5),
('gpt-5-mini-2025-08-07', 1),
('gpt-5-nano-2025-08-07', 1),
('gpt-5-chat-latest', 5),
('gpt-4.1-2025-04-14', 2),
('gpt-4.1-mini-2025-04-14', 1),
('gpt-4o-mini', 1),
('gpt-4o', 3),
('gpt-4-turbo', 10),
('gpt-3.5-turbo', 1),
-- Anthropic costs
('claude-opus-4-1-20250805', 21),
('claude-opus-4-20250514', 21),
('claude-sonnet-4-20250514', 5),
('claude-haiku-4-5-20251001', 4),
('claude-opus-4-5-20251101', 14),
('claude-sonnet-4-5-20250929', 9),
('claude-3-7-sonnet-20250219', 5),
('claude-3-haiku-20240307', 1),
-- AI/ML API costs
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
-- Groq costs
('llama-3.3-70b-versatile', 1),
('llama-3.1-8b-instant', 1),
-- Ollama costs
('llama3.3', 1),
('llama3.2', 1),
('llama3', 1),
('llama3.1:405b', 1),
('dolphin-mistral:latest', 1),
-- OpenRouter costs
('google/gemini-2.5-pro-preview-03-25', 4),
('google/gemini-3-pro-preview', 5),
('mistralai/mistral-nemo', 1),
('cohere/command-r-08-2024', 1),
('cohere/command-r-plus-08-2024', 3),
('deepseek/deepseek-chat', 2),
('perplexity/sonar', 1),
('perplexity/sonar-pro', 5),
('perplexity/sonar-deep-research', 10),
('nousresearch/hermes-3-llama-3.1-405b', 1),
('nousresearch/hermes-3-llama-3.1-70b', 1),
('amazon/nova-lite-v1', 1),
('amazon/nova-micro-v1', 1),
('amazon/nova-pro-v1', 1),
('microsoft/wizardlm-2-8x22b', 1),
('gryphe/mythomax-l2-13b', 1),
('meta-llama/llama-4-scout', 1),
('meta-llama/llama-4-maverick', 1),
('x-ai/grok-4', 9),
('x-ai/grok-4-fast', 1),
('x-ai/grok-4.1-fast', 1),
('x-ai/grok-code-fast-1', 1),
('moonshotai/kimi-k2', 1),
('qwen/qwen3-235b-a22b-thinking-2507', 1),
('qwen/qwen3-coder', 9),
('google/gemini-2.5-flash', 1),
('google/gemini-2.0-flash-001', 1),
('google/gemini-2.5-flash-lite-preview-06-17', 1),
('google/gemini-2.0-flash-lite-001', 1),
('deepseek/deepseek-r1-0528', 1),
('openai/gpt-oss-120b', 1),
('openai/gpt-oss-20b', 1),
-- Llama API costs
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
('Llama-3.3-8B-Instruct', 1),
('Llama-3.3-70B-Instruct', 1),
-- v0 costs
('v0-1.5-md', 1),
('v0-1.5-lg', 2),
('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;


@@ -1,25 +0,0 @@
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");


@@ -1,127 +0,0 @@
-- Add LlmModelCreator table
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
-- Create the LlmModelCreator table
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- Create unique index on name
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- Add creatorId column to LlmModel
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
-- Add foreign key constraint
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- Create index on creatorId
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Seed creators based on known model creators
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
ON CONFLICT ("name") DO NOTHING;
-- Update existing models with their creators
-- OpenAI models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
-- Anthropic models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
WHERE "slug" LIKE 'claude-%';
-- Meta/Llama models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
-- Google models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
-- Mistral models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
-- Cohere models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
-- DeepSeek models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
-- Perplexity models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
-- Qwen models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
-- xAI/Grok models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
-- Amazon models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
-- Microsoft models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
-- Moonshot models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
-- NVIDIA models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
-- Nous Research models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
-- Vercel/v0 models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
WHERE "slug" LIKE 'v0-%';
-- Dolphin models (Cognitive Computations / Eric Hartford)
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
WHERE "slug" LIKE 'dolphin-%';
-- Gryphe models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';


@@ -1,4 +0,0 @@
-- CreateIndex
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
-- This improves performance of model migration queries in the LLM registry
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));


@@ -1,52 +0,0 @@
-- Add GPT-5.2 model and update O3 slug
-- This migration adds the GPT-5.2 model introduced in the dev branch
-- Update O3 slug to match dev branch format
UPDATE "LlmModel"
SET "slug" = 'o3-2025-04-16'
WHERE "slug" = 'o3';
-- Update cost reference for O3 if needed
-- (costs are linked by model ID, so no update needed)
-- Add GPT-5.2 model
WITH provider_id AS (
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
'gpt-5.2-2025-12-11',
'GPT 5.2',
'OpenAI GPT-5.2 model',
p."id",
400000,
128000,
true,
'{}'::jsonb,
'{}'::jsonb
FROM provider_id p
ON CONFLICT ("slug") DO NOTHING;
-- Add cost for GPT-5.2
WITH model_id AS (
SELECT m."id", p."name" as provider_name
FROM "LlmModel" m
JOIN "LlmProvider" p ON p."id" = m."providerId"
WHERE m."slug" = 'gpt-5.2-2025-12-11'
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
3, -- Same cost tier as GPT-5.1
m.provider_name,
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM model_id m
WHERE NOT EXISTS (
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
);


@@ -1,11 +0,0 @@
-- Add isRecommended field to LlmModel table
-- This allows admins to mark a model as the recommended default
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
-- Set gpt-4o-mini as the default recommended model (if it exists)
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
-- Create unique partial index to enforce only one recommended model at the database level
-- This prevents multiple rows from having isRecommended = true
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;


@@ -1,61 +0,0 @@
-- Add new columns to LlmModel table for extended model metadata
-- These columns support the LLM Picker UI enhancements
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
-- Add creatorId column for model creator relationship (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
-- Add isRecommended column (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
-- Add index on creatorId if not exists
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Add foreign key for creatorId if not exists
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
-- Only add FK if LlmModelCreator table exists
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
END IF;
END IF;
END $$;
-- Update priceTier values for existing models based on original MODEL_METADATA
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
-- OpenAI models
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
-- Anthropic models
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
-- OpenRouter models - Pro/expensive tiers
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';


@@ -1,6 +0,0 @@
-- Add composite index on LlmModelMigration for optimized active migration queries
-- This index improves performance when querying for non-reverted migrations by model slug
-- Used by the billing system to apply customCreditCost overrides
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");


@@ -1,61 +0,0 @@
-- Sync LLM models with latest dev branch changes
-- This migration adds new models and removes deprecated ones
-- Remove models that were deleted from dev
DELETE FROM "LlmModelCost" WHERE "llmModelId" IN (
SELECT "id" FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219')
);
DELETE FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219');
-- Add new models from dev
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- New OpenAI model
('gpt-5.2-2025-12-11', 'GPT 5.2', 'openai', 400000, 128000),
-- New Anthropic model
('claude-opus-4-6', 'Claude 4.6 Opus', 'anthropic', 200000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Add costs for new models
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- New model costs (estimate based on similar models)
('gpt-5.2-2025-12-11', 5), -- Similar to GPT 5.1
('claude-opus-4-6', 21) -- Similar to other Opus 4.x models
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;


@@ -897,6 +897,29 @@ files = [
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
]
[[package]]
name = "claude-agent-sdk"
version = "0.1.35"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
]
[package.dependencies]
anyio = ">=4.0.0"
mcp = ">=0.1.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
[[package]]
name = "cleo"
version = "2.1.0"
@@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.3"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
]
[[package]]
name = "huggingface-hub"
version = "1.4.1"
@@ -3310,6 +3345,39 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mcp"
version = "1.26.0"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27.1"
httpx-sse = ">=0.4"
jsonschema = ">=4.20.0"
pydantic = ">=2.11.0,<3.0.0"
pydantic-settings = ">=2.5.2"
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
python-multipart = ">=0.0.9"
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
typing-extensions = ">=4.9.0"
typing-inspection = ">=0.4.1"
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -5994,7 +6062,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "platform_system == \"Windows\""
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
]
[package.dependencies]
anyio = ">=4.7.0"
starlette = ">=0.49.1"
[package.extras]
daphne = ["daphne (>=4.2.0)"]
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
granian = ["granian (>=2.3.1)"]
uvicorn = ["uvicorn (>=0.34.0)"]
[[package]]
name = "stagehand"
version = "0.5.9"
@@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"


@@ -16,6 +16,7 @@ anthropic = "^0.79.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"


@@ -1143,153 +1143,6 @@ enum APIKeyStatus {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
///////////// LLM REGISTRY AND BILLING DATA /////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
// This is distinct from BlockCostType (in backend/data/block.py) which defines
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
// while BlockCostType is for billing platform block executions.
enum LlmCostUnit {
RUN
TOKENS
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
supportsTools Boolean @default(true)
supportsJsonOutput Boolean @default(true)
supportsReasoning Boolean @default(false)
supportsParallelTool Boolean @default(false)
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
@@index([providerId, isEnabled])
@@index([creatorId])
@@index([slug])
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
@@unique([llmModelId, credentialProvider, unit])
@@index([llmModelId])
@@index([credentialProvider])
}
// Tracks model migrations for revert capability
// When a model is disabled with migration, we record which nodes were affected
// so they can be reverted when the original model is back online
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
customCreditCost Int?
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
@@index([sourceModelSlug])
@@index([targetModelSlug])
@@index([isReverted])
}
////////////// OAUTH PROVIDER TABLES //////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
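
Illustrative sketch (hypothetical — the schema above leaves billing integration as a TODO): one way billing could apply the customCreditCost override for nodes covered by an active, non-reverted migration. The function name and call site are assumptions; the field names come from LlmModelMigration.

import prisma.models

async def credit_cost_for_node(node_id: str, model_slug: str, default_cost: int) -> int:
    """Return the migration's customCreditCost if this node was migrated to model_slug."""
    migration = await prisma.models.LlmModelMigration.prisma().find_first(
        where={"targetModelSlug": model_slug, "isReverted": False},
    )
    if (
        migration
        and migration.customCreditCost is not None
        and node_id in (migration.migratedNodeIds or [])
    ):
        return migration.customCreditCost
    return default_cost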


@@ -0,0 +1,133 @@
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
Note: Bash command validation was removed — the SDK built-in Bash tool is not in
allowed_tools, and the bash_exec MCP tool runs with kernel-level network isolation
(unshare --net), so command-level parsing is unnecessary.
"""
from backend.api.features.chat.sdk.security_hooks import (
_validate_tool_access,
_validate_workspace_path,
)
SDK_CWD = "/tmp/copilot-test-session"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
def _reason(result: dict) -> str:
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
# ============================================================
# Workspace path validation (Read, Write, Edit, etc.)
# ============================================================
class TestWorkspacePathValidation:
def test_path_in_workspace(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_path_outside_workspace(self):
result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
assert _is_denied(result)
def test_tool_results_allowed(self):
result = _validate_workspace_path(
"Read",
{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
SDK_CWD,
)
assert not _is_denied(result)
def test_claude_settings_blocked(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
)
assert _is_denied(result)
def test_claude_projects_without_tool_results(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
)
assert _is_denied(result)
def test_no_path_allowed(self):
"""Glob/Grep without path defaults to cwd — should be allowed."""
result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
assert not _is_denied(result)
def test_path_traversal_with_dotdot(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
)
assert _is_denied(result)
# ============================================================
# Tool access validation
# ============================================================
class TestToolAccessValidation:
def test_blocked_tools(self):
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"Tool '{tool}' should be blocked"
def test_bash_builtin_blocked(self):
"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
assert _is_denied(result)
assert "Bash" in _reason(result)
def test_workspace_tools_delegate(self):
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_dangerous_pattern_blocked(self):
result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
assert _is_denied(result)
def test_safe_unknown_tool_allowed(self):
result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
assert not _is_denied(result)
# ============================================================
# Deny message quality (ntindle feedback)
# ============================================================
class TestDenyMessageClarity:
"""Deny messages must include [SECURITY] and 'cannot be bypassed'
so the model knows the restriction is enforced, not a suggestion."""
def test_blocked_tool_message(self):
reason = _reason(_validate_tool_access("bash", {}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_bash_builtin_blocked_message(self):
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_workspace_path_message(self):
reason = _reason(
_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
)
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason


@@ -0,0 +1,255 @@
"""Unit tests for JSONL transcript management utilities."""
import json
import os
from backend.api.features.chat.sdk.transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
# --- Fixtures ---
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
ASST_MSG = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {"role": "assistant", "content": "hello"},
}
PROGRESS_ENTRY = {
"type": "progress",
"uuid": "p1",
"parentUuid": "u1",
"data": {"type": "bash_progress", "stdout": "running..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)


# --- read_transcript_file ---
class TestReadTranscriptFile:
    def test_returns_content_for_valid_file(self, tmp_path):
        path = tmp_path / "session.jsonl"
        path.write_text(VALID_TRANSCRIPT)
        result = read_transcript_file(str(path))
        assert result is not None
        assert "user" in result

    def test_returns_none_for_missing_file(self):
        assert read_transcript_file("/nonexistent/path.jsonl") is None

    def test_returns_none_for_empty_path(self):
        assert read_transcript_file("") is None

    def test_returns_none_for_empty_file(self, tmp_path):
        path = tmp_path / "empty.jsonl"
        path.write_text("")
        assert read_transcript_file(str(path)) is None

    def test_returns_none_for_metadata_only(self, tmp_path):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
        path = tmp_path / "meta.jsonl"
        path.write_text(content)
        assert read_transcript_file(str(path)) is None

    def test_returns_none_for_invalid_json(self, tmp_path):
        path = tmp_path / "bad.jsonl"
        path.write_text("not json\n{}\n{}\n")
        assert read_transcript_file(str(path)) is None

    def test_no_size_limit(self, tmp_path):
        """Large files are accepted — bucket storage has no size limit."""
        big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
        path = tmp_path / "big.jsonl"
        path.write_text(content)
        result = read_transcript_file(str(path))
        assert result is not None


# --- write_transcript_to_tempfile ---
class TestWriteTranscriptToTempfile:
    """Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""

    def test_writes_file_and_returns_path(self):
        cwd = "/tmp/copilot-test-write"
        try:
            result = write_transcript_to_tempfile(
                VALID_TRANSCRIPT, "sess-1234-abcd", cwd
            )
            assert result is not None
            assert os.path.isfile(result)
            assert result.endswith(".jsonl")
            with open(result) as f:
                assert f.read() == VALID_TRANSCRIPT
        finally:
            import shutil
            shutil.rmtree(cwd, ignore_errors=True)

    def test_creates_parent_directory(self):
        cwd = "/tmp/copilot-test-mkdir"
        try:
            result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
            assert result is not None
            assert os.path.isdir(cwd)
        finally:
            import shutil
            shutil.rmtree(cwd, ignore_errors=True)

    def test_uses_session_id_prefix(self):
        cwd = "/tmp/copilot-test-prefix"
        try:
            result = write_transcript_to_tempfile(
                VALID_TRANSCRIPT, "abcdef12-rest", cwd
            )
            assert result is not None
            assert "abcdef12" in os.path.basename(result)
        finally:
            import shutil
            shutil.rmtree(cwd, ignore_errors=True)

    def test_rejects_cwd_outside_sandbox(self, tmp_path):
        cwd = str(tmp_path / "not-copilot")
        result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
        assert result is None


# --- validate_transcript ---
class TestValidateTranscript:
    def test_valid_transcript(self):
        assert validate_transcript(VALID_TRANSCRIPT) is True

    def test_none_content(self):
        assert validate_transcript(None) is False

    def test_empty_content(self):
        assert validate_transcript("") is False

    def test_metadata_only(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
        assert validate_transcript(content) is False

    def test_user_only_no_assistant(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
        assert validate_transcript(content) is False

    def test_assistant_only_no_user(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
        assert validate_transcript(content) is False

    def test_invalid_json_returns_false(self):
        assert validate_transcript("not json\n{}\n{}\n") is False


# --- strip_progress_entries ---
class TestStripProgressEntries:
    def test_strips_all_strippable_types(self):
        """All STRIPPABLE_TYPES are removed from the output."""
        entries = [
            USER_MSG,
            {"type": "progress", "uuid": "p1", "parentUuid": "u1"},
            {"type": "file-history-snapshot", "files": []},
            {"type": "queue-operation", "subtype": "create"},
            {"type": "summary", "text": "..."},
            {"type": "pr-link", "url": "..."},
            ASST_MSG,
        ]
        result = strip_progress_entries(_make_jsonl(*entries))
        result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
        assert result_types == {"user", "assistant"}
        for stype in STRIPPABLE_TYPES:
            assert stype not in result_types

    def test_reparents_children_of_stripped_entries(self):
        """An assistant message whose parent is a progress entry gets reparented."""
        progress = {
            "type": "progress",
            "uuid": "p1",
            "parentUuid": "u1",
            "data": {"type": "bash_progress"},
        }
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "p1",  # Points to progress
            "message": {"role": "assistant", "content": "done"},
        }
        content = _make_jsonl(USER_MSG, progress, asst)
        result = strip_progress_entries(content)
        lines = [json.loads(line) for line in result.strip().split("\n")]
        asst_entry = next(e for e in lines if e["type"] == "assistant")
        # Should be reparented to u1 (the user message)
        assert asst_entry["parentUuid"] == "u1"

    def test_reparents_through_chain(self):
        """Reparenting walks through multiple stripped entries."""
        p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
        p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
        p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "p3",  # 3 levels deep
            "message": {"role": "assistant", "content": "done"},
        }
        content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
        result = strip_progress_entries(content)
        lines = [json.loads(line) for line in result.strip().split("\n")]
        asst_entry = next(e for e in lines if e["type"] == "assistant")
        assert asst_entry["parentUuid"] == "u1"

    def test_preserves_non_strippable_entries(self):
        """User, assistant, and system entries are preserved."""
        system = {"type": "system", "uuid": "s1", "message": "prompt"}
        content = _make_jsonl(system, USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
        assert result_types == ["system", "user", "assistant"]

    def test_empty_input(self):
        result = strip_progress_entries("")
        # Should return just a newline (empty content stripped)
        assert result.strip() == ""

    def test_no_strippable_entries(self):
        """When there's nothing to strip, output matches input structure."""
        content = _make_jsonl(USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_lines = result.strip().split("\n")
        assert len(result_lines) == 2

    def test_handles_entries_without_uuid(self):
        """Entries without uuid field are handled gracefully."""
        no_uuid = {"type": "queue-operation", "subtype": "create"}
        content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
        # queue-operation is strippable
        assert "queue-operation" not in result_types
        assert "user" in result_types
        assert "assistant" in result_types

View File

@@ -1,8 +1,5 @@
"use client";
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { Cpu } from "@phosphor-icons/react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -29,11 +26,6 @@ const sidebarLinkGroups = [
href: "/admin/execution-analytics",
icon: <FileText className="h-6 w-6" />,
},
{
text: "LLM Registry",
href: "/admin/llms",
icon: <Cpu size={24} />,
},
{
text: "Admin User Management",
href: "/admin/settings",

View File

@@ -1,493 +0,0 @@
"use server";
import { revalidatePath } from "next/cache";
// Generated API functions
import {
getV2ListLlmProviders,
postV2CreateLlmProvider,
patchV2UpdateLlmProvider,
deleteV2DeleteLlmProvider,
getV2ListLlmModels,
postV2CreateLlmModel,
patchV2UpdateLlmModel,
patchV2ToggleLlmModelAvailability,
deleteV2DeleteLlmModelAndMigrateWorkflows,
getV2GetModelUsageCount,
getV2ListModelMigrations,
postV2RevertAModelMigration,
getV2ListModelCreators,
postV2CreateModelCreator,
patchV2UpdateModelCreator,
deleteV2DeleteModelCreator,
postV2SetRecommendedModel,
} from "@/app/api/__generated__/endpoints/admin/admin";
// Generated types
import type { LlmProvidersResponse } from "@/app/api/__generated__/models/llmProvidersResponse";
import type { LlmModelsResponse } from "@/app/api/__generated__/models/llmModelsResponse";
import type { UpsertLlmProviderRequest } from "@/app/api/__generated__/models/upsertLlmProviderRequest";
import type { CreateLlmModelRequest } from "@/app/api/__generated__/models/createLlmModelRequest";
import type { UpdateLlmModelRequest } from "@/app/api/__generated__/models/updateLlmModelRequest";
import type { ToggleLlmModelRequest } from "@/app/api/__generated__/models/toggleLlmModelRequest";
import type { LlmMigrationsResponse } from "@/app/api/__generated__/models/llmMigrationsResponse";
import type { LlmCreatorsResponse } from "@/app/api/__generated__/models/llmCreatorsResponse";
import type { UpsertLlmCreatorRequest } from "@/app/api/__generated__/models/upsertLlmCreatorRequest";
import type { LlmModelUsageResponse } from "@/app/api/__generated__/models/llmModelUsageResponse";
import { LlmCostUnit } from "@/app/api/__generated__/models/llmCostUnit";
const ADMIN_LLM_PATH = "/admin/llms";
// =============================================================================
// Utilities
// =============================================================================
/**
* Extracts and validates a required string field from FormData.
* Throws an error if the field is missing or empty.
*/
function getRequiredFormField(
formData: FormData,
fieldName: string,
displayName?: string,
): string {
const raw = formData.get(fieldName);
const value = raw ? String(raw).trim() : "";
if (!value) {
throw new Error(`${displayName || fieldName} is required`);
}
return value;
}
/**
* Extracts and validates a required positive number field from FormData.
* Throws an error if the field is missing, empty, or not a positive number.
*/
function getRequiredPositiveNumber(
formData: FormData,
fieldName: string,
displayName?: string,
): number {
const raw = formData.get(fieldName);
const value = Number(raw);
if (raw === null || raw === "" || !Number.isFinite(value) || value <= 0) {
throw new Error(`${displayName || fieldName} must be a positive number`);
}
return value;
}
/**
* Extracts and validates a required number field from FormData.
* Throws an error if the field is missing, empty, or not a finite number.
*/
function getRequiredNumber(
formData: FormData,
fieldName: string,
displayName?: string,
): number {
const raw = formData.get(fieldName);
const value = Number(raw);
if (raw === null || raw === "" || !Number.isFinite(value)) {
throw new Error(`${displayName || fieldName} is required`);
}
return value;
}
// =============================================================================
// Provider Actions
// =============================================================================
export async function fetchLlmProviders(): Promise<LlmProvidersResponse> {
const response = await getV2ListLlmProviders({ include_models: true });
if (response.status !== 200) {
throw new Error("Failed to fetch LLM providers");
}
return response.data;
}
export async function createLlmProviderAction(formData: FormData) {
const payload: UpsertLlmProviderRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
default_credential_provider: formData.get("default_credential_provider")
? String(formData.get("default_credential_provider")).trim()
: undefined,
default_credential_id: formData.get("default_credential_id")
? String(formData.get("default_credential_id")).trim()
: undefined,
default_credential_type: formData.get("default_credential_type")
? String(formData.get("default_credential_type")).trim()
: "api_key",
supports_tools: formData.getAll("supports_tools").includes("on"),
supports_json_output: formData
.getAll("supports_json_output")
.includes("on"),
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
supports_parallel_tool: formData
.getAll("supports_parallel_tool")
.includes("on"),
metadata: {},
};
const response = await postV2CreateLlmProvider(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmProviderAction(
formData: FormData,
): Promise<void> {
const providerId = getRequiredFormField(
formData,
"provider_id",
"Provider id",
);
const response = await deleteV2DeleteLlmProvider(providerId);
if (response.status !== 200) {
const errorData = response.data as { detail?: string };
throw new Error(errorData?.detail || "Failed to delete provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmProviderAction(formData: FormData) {
const providerId = getRequiredFormField(
formData,
"provider_id",
"Provider id",
);
const payload: UpsertLlmProviderRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
default_credential_provider: formData.get("default_credential_provider")
? String(formData.get("default_credential_provider")).trim()
: undefined,
default_credential_id: formData.get("default_credential_id")
? String(formData.get("default_credential_id")).trim()
: undefined,
default_credential_type: formData.get("default_credential_type")
? String(formData.get("default_credential_type")).trim()
: "api_key",
supports_tools: formData.getAll("supports_tools").includes("on"),
supports_json_output: formData
.getAll("supports_json_output")
.includes("on"),
supports_reasoning: formData.getAll("supports_reasoning").includes("on"),
supports_parallel_tool: formData
.getAll("supports_parallel_tool")
.includes("on"),
metadata: {},
};
const response = await patchV2UpdateLlmProvider(providerId, payload);
if (response.status !== 200) {
throw new Error("Failed to update LLM provider");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Model Actions
// =============================================================================
export async function fetchLlmModels(): Promise<LlmModelsResponse> {
const response = await getV2ListLlmModels();
if (response.status !== 200) {
throw new Error("Failed to fetch LLM models");
}
return response.data;
}
export async function createLlmModelAction(formData: FormData) {
const providerId = getRequiredFormField(formData, "provider_id", "Provider");
const creatorId = formData.get("creator_id");
const contextWindow = getRequiredPositiveNumber(
formData,
"context_window",
"Context window",
);
const creditCost = getRequiredNumber(formData, "credit_cost", "Credit cost");
// Fetch provider to get default credentials
const providersResponse = await getV2ListLlmProviders({
include_models: false,
});
if (providersResponse.status !== 200) {
throw new Error("Failed to fetch providers");
}
const provider = providersResponse.data.providers.find(
(p) => p.id === providerId,
);
if (!provider) {
throw new Error("Provider not found");
}
const payload: CreateLlmModelRequest = {
slug: String(formData.get("slug") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: providerId,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: contextWindow,
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.getAll("is_enabled").includes("on"),
capabilities: {},
metadata: {},
costs: [
{
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
credit_cost: creditCost,
credential_provider:
provider.default_credential_provider || provider.name,
credential_id: provider.default_credential_id || undefined,
credential_type: provider.default_credential_type || "api_key",
metadata: {},
},
],
};
const response = await postV2CreateLlmModel(payload);
if (response.status !== 200) {
throw new Error("Failed to create LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmModelAction(formData: FormData) {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const creatorId = formData.get("creator_id");
const payload: UpdateLlmModelRequest = {
display_name: formData.get("display_name")
? String(formData.get("display_name"))
: undefined,
description: formData.get("description")
? String(formData.get("description"))
: undefined,
provider_id: formData.get("provider_id")
? String(formData.get("provider_id"))
: undefined,
creator_id: creatorId ? String(creatorId) : undefined,
context_window: formData.get("context_window")
? Number(formData.get("context_window"))
: undefined,
max_output_tokens: formData.get("max_output_tokens")
? Number(formData.get("max_output_tokens"))
: undefined,
is_enabled: formData.has("is_enabled")
? formData.getAll("is_enabled").includes("on")
: undefined,
costs: formData.get("credit_cost")
? [
{
unit: (formData.get("unit") as LlmCostUnit) || LlmCostUnit.RUN,
credit_cost: Number(formData.get("credit_cost")),
credential_provider: String(
formData.get("credential_provider") || "",
).trim(),
credential_id: formData.get("credential_id")
? String(formData.get("credential_id"))
: undefined,
credential_type: formData.get("credential_type")
? String(formData.get("credential_type"))
: undefined,
metadata: {},
},
]
: undefined,
};
const response = await patchV2UpdateLlmModel(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to update LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function toggleLlmModelAction(formData: FormData): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const shouldEnable = formData.get("is_enabled") === "true";
const migrateToSlug = formData.get("migrate_to_slug");
const migrationReason = formData.get("migration_reason");
const customCreditCost = formData.get("custom_credit_cost");
const payload: ToggleLlmModelRequest = {
is_enabled: shouldEnable,
migrate_to_slug: migrateToSlug ? String(migrateToSlug) : undefined,
migration_reason: migrationReason ? String(migrationReason) : undefined,
custom_credit_cost: customCreditCost ? Number(customCreditCost) : undefined,
};
const response = await patchV2ToggleLlmModelAvailability(modelId, payload);
if (response.status !== 200) {
throw new Error("Failed to toggle LLM model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmModelAction(formData: FormData): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const rawReplacement = formData.get("replacement_model_slug");
const replacementModelSlug =
rawReplacement && String(rawReplacement).trim()
? String(rawReplacement).trim()
: undefined;
const response = await deleteV2DeleteLlmModelAndMigrateWorkflows(modelId, {
replacement_model_slug: replacementModelSlug,
});
if (response.status !== 200) {
throw new Error("Failed to delete model");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function fetchLlmModelUsage(
modelId: string,
): Promise<LlmModelUsageResponse> {
const response = await getV2GetModelUsageCount(modelId);
if (response.status !== 200) {
throw new Error("Failed to fetch model usage");
}
return response.data;
}
// =============================================================================
// Migration Actions
// =============================================================================
export async function fetchLlmMigrations(
includeReverted: boolean = false,
): Promise<LlmMigrationsResponse> {
const response = await getV2ListModelMigrations({
include_reverted: includeReverted,
});
if (response.status !== 200) {
throw new Error("Failed to fetch migrations");
}
return response.data;
}
export async function revertLlmMigrationAction(
formData: FormData,
): Promise<void> {
const migrationId = getRequiredFormField(
formData,
"migration_id",
"Migration id",
);
const response = await postV2RevertAModelMigration(migrationId, null);
if (response.status !== 200) {
throw new Error("Failed to revert migration");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Creator Actions
// =============================================================================
export async function fetchLlmCreators(): Promise<LlmCreatorsResponse> {
const response = await getV2ListModelCreators();
if (response.status !== 200) {
throw new Error("Failed to fetch creators");
}
return response.data;
}
export async function createLlmCreatorAction(
formData: FormData,
): Promise<void> {
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await postV2CreateModelCreator(payload);
if (response.status !== 200) {
throw new Error("Failed to create creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function updateLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
const payload: UpsertLlmCreatorRequest = {
name: String(formData.get("name") || "").trim(),
display_name: String(formData.get("display_name") || "").trim(),
description: formData.get("description")
? String(formData.get("description"))
: undefined,
website_url: formData.get("website_url")
? String(formData.get("website_url")).trim()
: undefined,
logo_url: formData.get("logo_url")
? String(formData.get("logo_url")).trim()
: undefined,
metadata: {},
};
const response = await patchV2UpdateModelCreator(creatorId, payload);
if (response.status !== 200) {
throw new Error("Failed to update creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
export async function deleteLlmCreatorAction(
formData: FormData,
): Promise<void> {
const creatorId = getRequiredFormField(formData, "creator_id", "Creator id");
const response = await deleteV2DeleteModelCreator(creatorId);
if (response.status !== 200) {
throw new Error("Failed to delete creator");
}
revalidatePath(ADMIN_LLM_PATH);
}
// =============================================================================
// Recommended Model Actions
// =============================================================================
export async function setRecommendedModelAction(
formData: FormData,
): Promise<void> {
const modelId = getRequiredFormField(formData, "model_id", "Model id");
const response = await postV2SetRecommendedModel({ model_id: modelId });
if (response.status !== 200) {
throw new Error("Failed to set recommended model");
}
revalidatePath(ADMIN_LLM_PATH);
}

View File

@@ -1,147 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddCreatorModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Creator
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Add a new model creator (the organization that made/trained the
model).
</div>
<form action={handleSubmit} className="space-y-4">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Name (slug) <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
Lowercase identifier (e.g., openai, meta, anthropic)
</p>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={2}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Creator of GPT models..."
/>
</div>
<div className="space-y-2">
<label
htmlFor="website_url"
className="text-sm font-medium text-foreground"
>
Website URL
</label>
<input
id="website_url"
name="website_url"
type="url"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="https://openai.com"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Add Creator"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,314 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { createLlmModelAction } from "../actions";
import { useRouter } from "next/navigation";
interface Props {
providers: LlmProvider[];
creators: LlmModelCreator[];
}
export function AddModelModal({ providers, creators }: Props) {
const [open, setOpen] = useState(false);
const [selectedCreatorId, setSelectedCreatorId] = useState("");
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to create model");
} finally {
setIsSubmitting(false);
}
}
// When provider changes, auto-select matching creator if one exists
function handleProviderChange(providerId: string) {
const provider = providers.find((p) => p.id === providerId);
if (provider) {
// Find creator with same name as provider (e.g., "openai" -> "openai")
const matchingCreator = creators.find((c) => c.name === provider.name);
if (matchingCreator) {
setSelectedCreatorId(matchingCreator.id);
} else {
// No matching creator (e.g., OpenRouter hosts other creators' models)
setSelectedCreatorId("");
}
}
}
return (
<Dialog
title="Add Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Model
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Register a new model slug, metadata, and pricing.
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core model details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="slug"
className="text-sm font-medium text-foreground"
>
Model Slug <span className="text-destructive">*</span>
</label>
<input
id="slug"
required
name="slug"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="gpt-4.1-mini-2025-04-14"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="GPT 4.1 Mini"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Model Configuration */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Model Configuration
</h3>
<p className="text-xs text-muted-foreground">
Model capabilities and limits
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="provider_id"
className="text-sm font-medium text-foreground"
>
Provider <span className="text-destructive">*</span>
</label>
<select
id="provider_id"
required
name="provider_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
defaultValue=""
onChange={(e) => handleProviderChange(e.target.value)}
>
<option value="" disabled>
Select provider
</option>
{providers.map((provider) => (
<option key={provider.id} value={provider.id}>
{provider.display_name} ({provider.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who hosts/serves the model
</p>
</div>
<div className="space-y-2">
<label
htmlFor="creator_id"
className="text-sm font-medium text-foreground"
>
Creator
</label>
<select
id="creator_id"
name="creator_id"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
value={selectedCreatorId}
onChange={(e) => setSelectedCreatorId(e.target.value)}
>
<option value="">No creator selected</option>
{creators.map((creator) => (
<option key={creator.id} value={creator.id}>
{creator.display_name} ({creator.name})
</option>
))}
</select>
<p className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</p>
</div>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="context_window"
className="text-sm font-medium text-foreground"
>
Context Window <span className="text-destructive">*</span>
</label>
<input
id="context_window"
required
type="number"
name="context_window"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="128000"
min={1}
/>
</div>
<div className="space-y-2">
<label
htmlFor="max_output_tokens"
className="text-sm font-medium text-foreground"
>
Max Output Tokens
</label>
<input
id="max_output_tokens"
type="number"
name="max_output_tokens"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="16384"
min={1}
/>
</div>
</div>
</div>
{/* Pricing */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">Pricing</h3>
<p className="text-xs text-muted-foreground">
Credit cost per run (credentials are managed via the provider)
</p>
</div>
<div className="grid gap-4 sm:grid-cols-1">
<div className="space-y-2">
<label
htmlFor="credit_cost"
className="text-sm font-medium text-foreground"
>
Credit Cost <span className="text-destructive">*</span>
</label>
<input
id="credit_cost"
required
type="number"
name="credit_cost"
step="1"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="5"
min={0}
/>
</div>
</div>
<p className="text-xs text-muted-foreground">
Credit cost is always in platform credits. Credentials are
inherited from the selected provider.
</p>
</div>
{/* Enabled Toggle */}
<div className="flex items-center gap-3 border-t border-border pt-6">
<input type="hidden" name="is_enabled" value="off" />
<input
id="is_enabled"
type="checkbox"
name="is_enabled"
defaultChecked
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor="is_enabled"
className="text-sm font-medium text-foreground"
>
Enabled by default
</label>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,268 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { createLlmProviderAction } from "../actions";
import { useRouter } from "next/navigation";
export function AddProviderModal() {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await createLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to create provider",
);
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Add Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="primary" size="small">
Add Provider
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Define a new upstream provider and default credential information.
</div>
{/* Setup Instructions */}
<div className="mb-6 rounded-lg border border-primary/30 bg-primary/5 p-4">
<div className="space-y-2">
<h4 className="text-sm font-semibold text-foreground">
Before Adding a Provider
</h4>
<p className="text-xs text-muted-foreground">
To use a new provider, you must first configure its credentials in
the backend:
</p>
<ol className="list-inside list-decimal space-y-1 text-xs text-muted-foreground">
<li>
Add the credential to{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/integrations/credentials_store.py
</code>{" "}
with a UUID, provider name, and settings secret reference
</li>
<li>
Add it to the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
backend/data/block_cost_config.py
</code>
</li>
<li>
Use the <strong>same provider name</strong> in the
&quot;Credential Provider&quot; field below that matches the key
in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono">
PROVIDER_CREDENTIALS
</code>
</li>
</ol>
</div>
</div>
<form action={handleSubmit} className="space-y-6">
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core provider details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Provider Slug <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="e.g. openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Default Credentials */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Default Credentials
</h3>
<p className="text-xs text-muted-foreground">
Credential provider name that matches the key in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>
</p>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_provider"
className="text-sm font-medium text-foreground"
>
Credential Provider <span className="text-destructive">*</span>
</label>
<input
id="default_credential_provider"
name="default_credential_provider"
required
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
<p className="text-xs text-muted-foreground">
<strong>Important:</strong> This must exactly match the key in
the{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>{" "}
dictionary in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
block_cost_config.py
</code>
. Common values: &quot;openai&quot;, &quot;anthropic&quot;,
&quot;groq&quot;, &quot;open_router&quot;, etc.
</p>
</div>
</div>
{/* Capabilities */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Capabilities
</h3>
<p className="text-xs text-muted-foreground">
Provider feature flags
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
{[
{ name: "supports_tools", label: "Supports tools" },
{ name: "supports_json_output", label: "Supports JSON output" },
{ name: "supports_reasoning", label: "Supports reasoning" },
{
name: "supports_parallel_tool",
label: "Supports parallel tool calls",
},
].map(({ name, label }) => (
<div
key={name}
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
>
<input type="hidden" name={name} value="off" />
<input
id={name}
type="checkbox"
name={name}
defaultChecked={
name !== "supports_reasoning" &&
name !== "supports_parallel_tool"
}
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor={name}
className="text-sm font-medium text-foreground"
>
{label}
</label>
</div>
))}
</div>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Creating..." : "Save Provider"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,195 +0,0 @@
"use client";
import { useState } from "react";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { updateLlmCreatorAction } from "../actions";
import { useRouter } from "next/navigation";
import { DeleteCreatorModal } from "./DeleteCreatorModal";
export function CreatorsTable({ creators }: { creators: LlmModelCreator[] }) {
if (!creators.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No creators registered yet.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Creator</TableHead>
<TableHead>Description</TableHead>
<TableHead>Website</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{creators.map((creator) => (
<TableRow key={creator.id}>
<TableCell>
<div className="font-medium">{creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{creator.name}
</div>
</TableCell>
<TableCell>
<span className="text-sm text-muted-foreground">
{creator.description || "—"}
</span>
</TableCell>
<TableCell>
{creator.website_url ? (
<a
href={creator.website_url}
target="_blank"
rel="noopener noreferrer"
className="text-sm text-primary hover:underline"
>
{(() => {
try {
return new URL(creator.website_url).hostname;
} catch {
return creator.website_url;
}
})()}
</a>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
<EditCreatorModal creator={creator} />
<DeleteCreatorModal creator={creator} />
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
}
function EditCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update creator");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "512px" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label className="text-sm font-medium">Name (slug)</label>
<input
required
name="name"
defaultValue={creator.name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Display Name</label>
<input
required
name="display_name"
defaultValue={creator.display_name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Description</label>
<textarea
name="description"
rows={2}
defaultValue={creator.description ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Website URL</label>
<input
name="website_url"
type="url"
defaultValue={creator.website_url ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,107 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import { deleteLlmCreatorAction } from "../actions";
export function DeleteCreatorModal({ creator }: { creator: LlmModelCreator }) {
const [open, setOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmCreatorAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete creator");
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Creator"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "480px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{creator.display_name}</span>{" "}
<span className="text-muted-foreground">
({creator.name})
</span>
</p>
<p className="mt-2 text-muted-foreground">
Models using this creator will have their creator field
cleared. This is safe and won&apos;t affect model
functionality.
</p>
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="creator_id" value={creator.id} />
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isDeleting}
type="button"
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={isDeleting}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting ? "Deleting..." : "Delete Creator"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,224 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { deleteLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DeleteModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [selectedReplacement, setSelectedReplacement] = useState<string>("");
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [usageLoading, setUsageLoading] = useState(false);
const [usageError, setUsageError] = useState<string | null>(null);
// Filter out the current model and disabled models from replacement options
const replacementOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
// Check if migration is required (has blocks using this model)
const requiresMigration = usageCount !== null && usageCount > 0;
async function fetchUsage() {
setUsageLoading(true);
setUsageError(null);
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch (err) {
console.error("Failed to fetch model usage:", err);
setUsageError("Failed to load usage count");
setUsageCount(null);
} finally {
setUsageLoading(false);
}
}
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to delete model");
} finally {
setIsDeleting(false);
}
}
// Determine if delete button should be enabled
const canDelete =
!isDeleting &&
!usageLoading &&
usageCount !== null &&
(requiresMigration
? selectedReplacement && replacementOptions.length > 0
: true);
return (
<Dialog
title="Delete Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
setUsageError(null);
setError(null);
setSelectedReplacement("");
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
{requiresMigration
? "This action cannot be undone. All workflows using this model will be migrated to the replacement model you select."
: "This action cannot be undone."}
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageLoading && (
<p className="mt-2 text-muted-foreground">
Loading usage count...
</p>
)}
{usageError && (
<p className="mt-2 text-destructive">{usageError}</p>
)}
{!usageLoading && !usageError && usageCount !== null && (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use this model
</p>
)}
{requiresMigration && (
<p className="mt-2 text-muted-foreground">
All workflows currently using this model will be
automatically updated to use the replacement model you
choose below.
</p>
)}
{!usageLoading && usageCount === 0 && (
<p className="mt-2 text-muted-foreground">
No workflows are using this model. It can be safely deleted.
</p>
)}
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input
type="hidden"
name="replacement_model_slug"
value={selectedReplacement}
/>
{requiresMigration && (
<label className="text-sm font-medium">
<span className="mb-2 block">
Select Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedReplacement}
onChange={(e) => setSelectedReplacement(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{replacementOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{replacementOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No replacement models available. You must have at least one
other enabled model before deleting this one.
</p>
)}
</label>
)}
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setSelectedReplacement("");
setError(null);
}}
disabled={isDeleting}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={!canDelete}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isDeleting
? "Deleting..."
: requiresMigration
? "Delete and Migrate"
: "Delete"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,129 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { deleteLlmProviderAction } from "../actions";
export function DeleteProviderModal({ provider }: { provider: LlmProvider }) {
const [open, setOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
const modelCount = provider.models?.length ?? 0;
const hasModels = modelCount > 0;
async function handleDelete(formData: FormData) {
setIsDeleting(true);
setError(null);
try {
await deleteLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to delete provider",
);
} finally {
setIsDeleting(false);
}
}
return (
<Dialog
title="Delete Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "480px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0 text-destructive hover:bg-destructive/10"
>
Delete
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-4">
<div
className={`rounded-lg border p-4 ${
hasModels
? "border-destructive/30 bg-destructive/10"
: "border-amber-500/30 bg-amber-500/10 dark:border-amber-400/30 dark:bg-amber-400/10"
}`}
>
<div className="flex items-start gap-3">
<div
className={`flex-shrink-0 ${
hasModels
? "text-destructive"
: "text-amber-600 dark:text-amber-400"
}`}
>
{hasModels ? "🚫" : "⚠️"}
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to delete:</p>
<p className="mt-1">
<span className="font-medium">{provider.display_name}</span>{" "}
<span className="text-muted-foreground">
({provider.name})
</span>
</p>
{hasModels ? (
<p className="mt-2 text-destructive">
This provider has {modelCount} model(s). You must delete all
models before you can delete this provider.
</p>
) : (
<p className="mt-2 text-muted-foreground">
This provider has no models and can be safely deleted.
</p>
)}
</div>
</div>
</div>
<form action={handleDelete} className="space-y-4">
<input type="hidden" name="provider_id" value={provider.id} />
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isDeleting}
type="button"
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={isDeleting || hasModels}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50"
>
{isDeleting ? "Deleting..." : "Delete Provider"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,288 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { toggleLlmModelAction, fetchLlmModelUsage } from "../actions";
export function DisableModelModal({
model,
availableModels,
}: {
model: LlmModel;
availableModels: LlmModel[];
}) {
const [open, setOpen] = useState(false);
const [isDisabling, setIsDisabling] = useState(false);
const [error, setError] = useState<string | null>(null);
const [usageCount, setUsageCount] = useState<number | null>(null);
const [selectedMigration, setSelectedMigration] = useState<string>("");
const [wantsMigration, setWantsMigration] = useState(false);
const [migrationReason, setMigrationReason] = useState("");
const [customCreditCost, setCustomCreditCost] = useState<string>("");
// Filter out the current model and disabled models from replacement options
const migrationOptions = availableModels.filter(
(m) => m.id !== model.id && m.is_enabled,
);
async function fetchUsage() {
try {
const usage = await fetchLlmModelUsage(model.id);
setUsageCount(usage.node_count);
} catch {
setUsageCount(null);
}
}
async function handleDisable(formData: FormData) {
setIsDisabling(true);
setError(null);
try {
await toggleLlmModelAction(formData);
setOpen(false);
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to disable model");
} finally {
setIsDisabling(false);
}
}
function resetState() {
setError(null);
setSelectedMigration("");
setWantsMigration(false);
setMigrationReason("");
setCustomCreditCost("");
}
const hasUsage = usageCount !== null && usageCount > 0;
return (
<Dialog
title="Disable Model"
controlled={{
isOpen: open,
set: async (isOpen) => {
setOpen(isOpen);
if (isOpen) {
setUsageCount(null);
resetState();
await fetchUsage();
}
},
}}
styling={{ maxWidth: "600px" }}
>
<Dialog.Trigger>
<Button
type="button"
variant="outline"
size="small"
className="min-w-0"
>
Disable
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Disabling a model will hide it from users when creating new workflows.
</div>
<div className="space-y-4">
<div className="rounded-lg border border-amber-500/30 bg-amber-500/10 p-4 dark:border-amber-400/30 dark:bg-amber-400/10">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 text-amber-600 dark:text-amber-400">
</div>
<div className="text-sm text-foreground">
<p className="font-semibold">You are about to disable:</p>
<p className="mt-1">
<span className="font-medium">{model.display_name}</span>{" "}
<span className="text-muted-foreground">({model.slug})</span>
</p>
{usageCount === null ? (
<p className="mt-2 text-muted-foreground">
Loading usage data...
</p>
) : usageCount > 0 ? (
<p className="mt-2 font-semibold">
Impact: {usageCount} block{usageCount !== 1 ? "s" : ""}{" "}
currently use{usageCount === 1 ? "s" : ""} this model
</p>
) : (
<p className="mt-2 text-muted-foreground">
No workflows are currently using this model.
</p>
)}
</div>
</div>
</div>
{hasUsage && (
<div className="space-y-4 rounded-lg border border-border bg-muted/50 p-4">
<label className="flex items-start gap-3">
<input
type="checkbox"
checked={wantsMigration}
onChange={(e) => {
setWantsMigration(e.target.checked);
if (!e.target.checked) {
setSelectedMigration("");
}
}}
className="mt-1"
/>
<div className="text-sm">
<span className="font-medium">
Migrate existing workflows to another model
</span>
<p className="mt-1 text-muted-foreground">
Creates a revertible migration record. If unchecked,
existing workflows will use automatic fallback to an enabled
model from the same provider.
</p>
</div>
</label>
{wantsMigration && (
<div className="space-y-4 border-t border-border pt-4">
<label className="block text-sm font-medium">
<span className="mb-2 block">
Replacement Model{" "}
<span className="text-destructive">*</span>
</span>
<select
required
value={selectedMigration}
onChange={(e) => setSelectedMigration(e.target.value)}
className="w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="">-- Choose a replacement model --</option>
{migrationOptions.map((m) => (
<option key={m.id} value={m.slug}>
{m.display_name} ({m.slug})
</option>
))}
</select>
{migrationOptions.length === 0 && (
<p className="mt-2 text-xs text-destructive">
No other enabled models available for migration.
</p>
)}
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Migration Reason{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="text"
value={migrationReason}
onChange={(e) => setMigrationReason(e.target.value)}
placeholder="e.g., Provider outage, Cost reduction"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Helps track why the migration was made
</p>
</label>
<label className="block text-sm font-medium">
<span className="mb-2 block">
Custom Credit Cost{" "}
<span className="font-normal text-muted-foreground">
(optional)
</span>
</span>
<input
type="number"
min="0"
value={customCreditCost}
onChange={(e) => setCustomCreditCost(e.target.value)}
placeholder="Leave blank to use target model's cost"
className="w-full rounded border border-input bg-background p-2 text-sm"
/>
<p className="mt-1 text-xs text-muted-foreground">
Override pricing for migrated workflows. When set, billing
will use this cost instead of the target model&apos;s
cost.
</p>
</label>
</div>
)}
</div>
)}
<form action={handleDisable} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<input type="hidden" name="is_enabled" value="false" />
{wantsMigration && selectedMigration && (
<>
<input
type="hidden"
name="migrate_to_slug"
value={selectedMigration}
/>
{migrationReason && (
<input
type="hidden"
name="migration_reason"
value={migrationReason}
/>
)}
{customCreditCost && (
<input
type="hidden"
name="custom_credit_cost"
value={customCreditCost}
/>
)}
</>
)}
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
onClick={() => {
setOpen(false);
resetState();
}}
disabled={isDisabling}
>
Cancel
</Button>
<Button
type="submit"
variant="primary"
size="small"
disabled={
isDisabling ||
(wantsMigration && !selectedMigration) ||
usageCount === null
}
>
{isDisabling
? "Disabling..."
: wantsMigration && selectedMigration
? "Disable & Migrate"
: "Disable Model"}
</Button>
</Dialog.Footer>
</form>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,223 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { updateLlmModelAction } from "../actions";
export function EditModelModal({
model,
providers,
creators,
}: {
model: LlmModel;
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
const router = useRouter();
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const cost = model.costs?.[0];
const provider = providers.find((p) => p.id === model.provider_id);
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmModelAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to update model");
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Model"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small" className="min-w-0">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Update model metadata and pricing information.
</div>
{error && (
<div className="mb-4 rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<form action={handleSubmit} className="space-y-4">
<input type="hidden" name="model_id" value={model.id} />
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Display Name
<input
required
name="display_name"
defaultValue={model.display_name}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
/>
</label>
<label className="text-sm font-medium">
Provider
<select
required
name="provider_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.provider_id}
>
{providers.map((p) => (
<option key={p.id} value={p.id}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who hosts/serves the model
</span>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Creator
<select
name="creator_id"
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
defaultValue={model.creator_id ?? ""}
>
<option value="">No creator selected</option>
{creators.map((c) => (
<option key={c.id} value={c.id}>
{c.display_name} ({c.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Who made/trained the model (e.g., OpenAI, Meta)
</span>
</label>
</div>
<label className="text-sm font-medium">
Description
<textarea
name="description"
rows={2}
defaultValue={model.description ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
placeholder="Optional description..."
/>
</label>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Context Window
<input
required
type="number"
name="context_window"
defaultValue={model.context_window}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
<label className="text-sm font-medium">
Max Output Tokens
<input
type="number"
name="max_output_tokens"
defaultValue={model.max_output_tokens ?? undefined}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={1}
/>
</label>
</div>
<div className="grid gap-4 md:grid-cols-2">
<label className="text-sm font-medium">
Credit Cost
<input
required
type="number"
name="credit_cost"
defaultValue={cost?.credit_cost ?? 0}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
min={0}
/>
<span className="text-xs text-muted-foreground">
Credits charged per run
</span>
</label>
<label className="text-sm font-medium">
Credential Provider
<select
required
name="credential_provider"
defaultValue={cost?.credential_provider ?? provider?.name ?? ""}
className="mt-1 w-full rounded border border-input bg-background p-2 text-sm"
>
<option value="" disabled>
Select provider
</option>
{providers.map((p) => (
<option key={p.id} value={p.name}>
{p.display_name} ({p.name})
</option>
))}
</select>
<span className="text-xs text-muted-foreground">
Must match a key in PROVIDER_CREDENTIALS
</span>
</label>
</div>
{/* Hidden defaults for credential_type and unit */}
<input
type="hidden"
name="credential_type"
value={
cost?.credential_type ??
provider?.default_credential_type ??
"api_key"
}
/>
<input type="hidden" name="unit" value={cost?.unit ?? "RUN"} />
<Dialog.Footer>
<Button
type="button"
variant="ghost"
size="small"
onClick={() => setOpen(false)}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Updating..." : "Update Model"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,263 +0,0 @@
"use client";
import { useState } from "react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Button } from "@/components/atoms/Button/Button";
import { updateLlmProviderAction } from "../actions";
import { useRouter } from "next/navigation";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
export function EditProviderModal({ provider }: { provider: LlmProvider }) {
const [open, setOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
async function handleSubmit(formData: FormData) {
setIsSubmitting(true);
setError(null);
try {
await updateLlmProviderAction(formData);
setOpen(false);
router.refresh();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to update provider",
);
} finally {
setIsSubmitting(false);
}
}
return (
<Dialog
title="Edit Provider"
controlled={{ isOpen: open, set: setOpen }}
styling={{ maxWidth: "768px", maxHeight: "90vh", overflowY: "auto" }}
>
<Dialog.Trigger>
<Button variant="outline" size="small">
Edit
</Button>
</Dialog.Trigger>
<Dialog.Content>
<div className="mb-4 text-sm text-muted-foreground">
Update provider configuration and capabilities.
</div>
<form action={handleSubmit} className="space-y-6">
<input type="hidden" name="provider_id" value={provider.id} />
{/* Basic Information */}
<div className="space-y-4">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Basic Information
</h3>
<p className="text-xs text-muted-foreground">
Core provider details
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="name"
className="text-sm font-medium text-foreground"
>
Provider Slug <span className="text-destructive">*</span>
</label>
<input
id="name"
required
name="name"
defaultValue={provider.name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="e.g. openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="display_name"
className="text-sm font-medium text-foreground"
>
Display Name <span className="text-destructive">*</span>
</label>
<input
id="display_name"
required
name="display_name"
defaultValue={provider.display_name}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="OpenAI"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="description"
className="text-sm font-medium text-foreground"
>
Description
</label>
<textarea
id="description"
name="description"
rows={3}
defaultValue={provider.description ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional description..."
/>
</div>
</div>
{/* Default Credentials */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Default Credentials
</h3>
<p className="text-xs text-muted-foreground">
Credential provider name that matches the key in{" "}
<code className="rounded bg-muted px-1 py-0.5 font-mono text-xs">
PROVIDER_CREDENTIALS
</code>
</p>
</div>
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<label
htmlFor="default_credential_provider"
className="text-sm font-medium text-foreground"
>
Credential Provider
</label>
<input
id="default_credential_provider"
name="default_credential_provider"
defaultValue={provider.default_credential_provider ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="openai"
/>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_id"
className="text-sm font-medium text-foreground"
>
Credential ID
</label>
<input
id="default_credential_id"
name="default_credential_id"
defaultValue={provider.default_credential_id ?? ""}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="Optional credential ID"
/>
</div>
</div>
<div className="space-y-2">
<label
htmlFor="default_credential_type"
className="text-sm font-medium text-foreground"
>
Credential Type
</label>
<input
id="default_credential_type"
name="default_credential_type"
defaultValue={provider.default_credential_type ?? "api_key"}
className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus:border-primary focus:outline-none focus:ring-2 focus:ring-primary/20"
placeholder="api_key"
/>
</div>
</div>
{/* Capabilities */}
<div className="space-y-4 border-t border-border pt-6">
<div className="space-y-1">
<h3 className="text-sm font-semibold text-foreground">
Capabilities
</h3>
<p className="text-xs text-muted-foreground">
Provider feature flags
</p>
</div>
<div className="grid gap-3 sm:grid-cols-2">
{[
{
name: "supports_tools",
label: "Supports tools",
checked: provider.supports_tools,
},
{
name: "supports_json_output",
label: "Supports JSON output",
checked: provider.supports_json_output,
},
{
name: "supports_reasoning",
label: "Supports reasoning",
checked: provider.supports_reasoning,
},
{
name: "supports_parallel_tool",
label: "Supports parallel tool calls",
checked: provider.supports_parallel_tool,
},
].map(({ name, label, checked }) => (
<div
key={name}
className="flex items-center gap-3 rounded-md border border-border bg-muted/30 px-4 py-3 transition-colors hover:bg-muted/50"
>
<input type="hidden" name={name} value="off" />
<input
id={name}
type="checkbox"
name={name}
defaultChecked={checked}
className="h-4 w-4 rounded border-input"
/>
<label
htmlFor={name}
className="text-sm font-medium text-foreground"
>
{label}
</label>
</div>
))}
</div>
</div>
{error && (
<div className="rounded-lg border border-destructive/30 bg-destructive/10 p-3 text-sm text-destructive">
{error}
</div>
)}
<Dialog.Footer>
<Button
variant="ghost"
size="small"
type="button"
onClick={() => {
setOpen(false);
setError(null);
}}
disabled={isSubmitting}
>
Cancel
</Button>
<Button
variant="primary"
size="small"
type="submit"
disabled={isSubmitting}
>
{isSubmitting ? "Saving..." : "Save Changes"}
</Button>
</Dialog.Footer>
</form>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,131 +0,0 @@
"use client";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { ErrorBoundary } from "@/components/molecules/ErrorBoundary/ErrorBoundary";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { AddProviderModal } from "./AddProviderModal";
import { AddModelModal } from "./AddModelModal";
import { AddCreatorModal } from "./AddCreatorModal";
import { ProviderList } from "./ProviderList";
import { ModelsTable } from "./ModelsTable";
import { MigrationsTable } from "./MigrationsTable";
import { CreatorsTable } from "./CreatorsTable";
import { RecommendedModelSelector } from "./RecommendedModelSelector";
interface Props {
providers: LlmProvider[];
models: LlmModel[];
migrations: LlmModelMigration[];
creators: LlmModelCreator[];
}
function AdminErrorFallback() {
return (
<div className="mx-auto max-w-xl p-6">
<ErrorCard
responseError={{
message:
"An error occurred while loading the LLM Registry. Please refresh the page.",
}}
context="llm-registry"
onRetry={() => window.location.reload()}
/>
</div>
);
}
export function LlmRegistryDashboard({
providers,
models,
migrations,
creators,
}: Props) {
return (
<ErrorBoundary fallback={<AdminErrorFallback />} context="llm-registry">
<div className="mx-auto p-6">
<div className="flex flex-col gap-6">
{/* Header */}
<div>
<h1 className="text-3xl font-bold">LLM Registry</h1>
<p className="text-muted-foreground">
Manage providers, creators, models, and credit pricing
</p>
</div>
{/* Active Migrations Section - Only show if there are migrations */}
{migrations.length > 0 && (
<div className="rounded-lg border border-primary/30 bg-primary/5 p-6 shadow-sm">
<div className="mb-4">
<h2 className="text-xl font-semibold">Active Migrations</h2>
<p className="mt-1 text-sm text-muted-foreground">
These migrations can be reverted to restore workflows to their
original model
</p>
</div>
<MigrationsTable migrations={migrations} />
</div>
)}
{/* Providers & Creators Section - Side by Side */}
<div className="grid gap-6 lg:grid-cols-2">
{/* Providers */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Providers</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who hosts/serves the models
</p>
</div>
<AddProviderModal />
</div>
<ProviderList providers={providers} />
</div>
{/* Creators */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Creators</h2>
<p className="mt-1 text-sm text-muted-foreground">
Who made/trained the models
</p>
</div>
<AddCreatorModal />
</div>
<CreatorsTable creators={creators} />
</div>
</div>
{/* Models Section */}
<div className="rounded-lg border bg-card p-6 shadow-sm">
<div className="mb-4 flex items-center justify-between">
<div>
<h2 className="text-xl font-semibold">Models</h2>
<p className="mt-1 text-sm text-muted-foreground">
Toggle availability, adjust context windows, and update credit
pricing
</p>
</div>
<AddModelModal providers={providers} creators={creators} />
</div>
{/* Recommended Model Selector */}
<div className="mb-6">
<RecommendedModelSelector models={models} />
</div>
<ModelsTable
models={models}
providers={providers}
creators={creators}
/>
</div>
</div>
</div>
</ErrorBoundary>
);
}

View File

@@ -1,133 +0,0 @@
"use client";
import { useState } from "react";
import type { LlmModelMigration } from "@/app/api/__generated__/models/llmModelMigration";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { revertLlmMigrationAction } from "../actions";
export function MigrationsTable({
migrations,
}: {
migrations: LlmModelMigration[];
}) {
if (!migrations.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No active migrations. Migrations are created when you disable a model
with the &quot;Migrate existing workflows&quot; option.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Migration</TableHead>
<TableHead>Reason</TableHead>
<TableHead>Nodes Affected</TableHead>
<TableHead>Custom Cost</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{migrations.map((migration) => (
<MigrationRow key={migration.id} migration={migration} />
))}
</TableBody>
</Table>
</div>
);
}
function MigrationRow({ migration }: { migration: LlmModelMigration }) {
const [isReverting, setIsReverting] = useState(false);
const [error, setError] = useState<string | null>(null);
async function handleRevert(formData: FormData) {
setIsReverting(true);
setError(null);
try {
await revertLlmMigrationAction(formData);
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to revert migration",
);
} finally {
setIsReverting(false);
}
}
const createdDate = new Date(migration.created_at);
return (
<>
<TableRow>
<TableCell>
<div className="text-sm">
<span className="font-medium">{migration.source_model_slug}</span>
<span className="mx-2 text-muted-foreground"></span>
<span className="font-medium">{migration.target_model_slug}</span>
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{migration.reason || "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm">{migration.node_count}</div>
</TableCell>
<TableCell>
<div className="text-sm">
{migration.custom_credit_cost !== null &&
migration.custom_credit_cost !== undefined
? `${migration.custom_credit_cost} credits`
: "—"}
</div>
</TableCell>
<TableCell>
<div className="text-sm text-muted-foreground">
{createdDate.toLocaleDateString()}{" "}
{createdDate.toLocaleTimeString([], {
hour: "2-digit",
minute: "2-digit",
})}
</div>
</TableCell>
<TableCell className="text-right">
<form action={handleRevert} className="inline">
<input type="hidden" name="migration_id" value={migration.id} />
<Button
type="submit"
variant="outline"
size="small"
disabled={isReverting}
>
{isReverting ? "Reverting..." : "Revert"}
</Button>
</form>
</TableCell>
</TableRow>
{error && (
<TableRow>
<TableCell colSpan={6}>
<div className="rounded border border-destructive/30 bg-destructive/10 p-2 text-sm text-destructive">
{error}
</div>
</TableCell>
</TableRow>
)}
</>
);
}

View File

@@ -1,265 +0,0 @@
"use client";
import { useState, useEffect, useRef } from "react";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import type { LlmModelCreator } from "@/app/api/__generated__/models/llmModelCreator";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import { Button } from "@/components/atoms/Button/Button";
import { toggleLlmModelAction } from "../actions";
import { DeleteModelModal } from "./DeleteModelModal";
import { DisableModelModal } from "./DisableModelModal";
import { EditModelModal } from "./EditModelModal";
import { Star, Spinner } from "@phosphor-icons/react";
import { getV2ListLlmModels } from "@/app/api/__generated__/endpoints/admin/admin";
const PAGE_SIZE = 50;
export function ModelsTable({
models: initialModels,
providers,
creators,
}: {
models: LlmModel[];
providers: LlmProvider[];
creators: LlmModelCreator[];
}) {
const [models, setModels] = useState<LlmModel[]>(initialModels);
const [currentPage, setCurrentPage] = useState(1);
const [hasMore, setHasMore] = useState(initialModels.length === PAGE_SIZE);
const [isLoading, setIsLoading] = useState(false);
const loadedPagesRef = useRef(1);
// Sync with parent when initialModels changes (e.g., after enable/disable)
// Re-fetch all loaded pages to preserve expanded state
useEffect(() => {
async function refetchAllPages() {
const pagesToLoad = loadedPagesRef.current;
if (pagesToLoad === 1) {
// Only first page loaded, just use initialModels
setModels(initialModels);
setHasMore(initialModels.length === PAGE_SIZE);
return;
}
// Re-fetch all pages we had loaded
const allModels: LlmModel[] = [...initialModels];
let lastPageHadFullResults = initialModels.length === PAGE_SIZE;
for (let page = 2; page <= pagesToLoad; page++) {
try {
const response = await getV2ListLlmModels({
page,
page_size: PAGE_SIZE,
});
if (response.status === 200) {
allModels.push(...response.data.models);
lastPageHadFullResults = response.data.models.length === PAGE_SIZE;
}
} catch (err) {
console.error(`Error refetching page ${page}:`, err);
break;
}
}
setModels(allModels);
setHasMore(lastPageHadFullResults);
}
refetchAllPages();
}, [initialModels]);
async function loadMore() {
if (isLoading) return;
setIsLoading(true);
try {
const nextPage = currentPage + 1;
const response = await getV2ListLlmModels({
page: nextPage,
page_size: PAGE_SIZE,
});
if (response.status === 200) {
setModels((prev) => [...prev, ...response.data.models]);
setCurrentPage(nextPage);
loadedPagesRef.current = nextPage;
setHasMore(response.data.models.length === PAGE_SIZE);
}
} catch (err) {
console.error("Error loading more models:", err);
} finally {
setIsLoading(false);
}
}
if (!models.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No models registered yet.
</div>
);
}
const providerLookup = new Map(
providers.map((provider) => [provider.id, provider]),
);
return (
<div>
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead>Provider</TableHead>
<TableHead>Creator</TableHead>
<TableHead>Context Window</TableHead>
<TableHead>Max Output</TableHead>
<TableHead>Cost</TableHead>
<TableHead>Status</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{models.map((model) => {
const cost = model.costs?.[0];
const provider = providerLookup.get(model.provider_id);
return (
<TableRow
key={model.id}
className={model.is_enabled ? "" : "opacity-60"}
>
<TableCell>
<div className="font-medium">{model.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.slug}
</div>
</TableCell>
<TableCell>
{provider ? (
<>
<div>{provider.display_name}</div>
<div className="text-xs text-muted-foreground">
{provider.name}
</div>
</>
) : (
model.provider_id
)}
</TableCell>
<TableCell>
{model.creator ? (
<>
<div>{model.creator.display_name}</div>
<div className="text-xs text-muted-foreground">
{model.creator.name}
</div>
</>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell>{model.context_window.toLocaleString()}</TableCell>
<TableCell>
{model.max_output_tokens
? model.max_output_tokens.toLocaleString()
: "—"}
</TableCell>
<TableCell>
{cost ? (
<>
<div className="font-medium">
{cost.credit_cost} credits
</div>
<div className="text-xs text-muted-foreground">
{cost.credential_provider}
</div>
</>
) : (
"—"
)}
</TableCell>
<TableCell>
<div className="flex flex-col gap-1">
<span
className={`inline-flex rounded-full px-2.5 py-1 text-xs font-semibold ${
model.is_enabled
? "bg-primary/10 text-primary"
: "bg-muted text-muted-foreground"
}`}
>
{model.is_enabled ? "Enabled" : "Disabled"}
</span>
{model.is_recommended && (
<span className="inline-flex items-center gap-1 rounded-full bg-amber-500/10 px-2.5 py-1 text-xs font-semibold text-amber-600 dark:text-amber-400">
<Star size={12} weight="fill" />
Recommended
</span>
)}
</div>
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
{model.is_enabled ? (
<DisableModelModal
model={model}
availableModels={models}
/>
) : (
<EnableModelButton modelId={model.id} />
)}
<EditModelModal
model={model}
providers={providers}
creators={creators}
/>
<DeleteModelModal
model={model}
availableModels={models}
/>
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
{hasMore && (
<div className="mt-4 flex justify-center">
<Button onClick={loadMore} disabled={isLoading} variant="outline">
{isLoading ? (
<>
<Spinner className="mr-2 h-4 w-4 animate-spin" />
Loading...
</>
) : (
"Load More"
)}
</Button>
</div>
)}
</div>
);
}
function EnableModelButton({ modelId }: { modelId: string }) {
return (
<form action={toggleLlmModelAction} className="inline">
<input type="hidden" name="model_id" value={modelId} />
<input type="hidden" name="is_enabled" value="true" />
<Button type="submit" variant="outline" size="small" className="min-w-0">
Enable
</Button>
</form>
);
}

View File

@@ -1,94 +0,0 @@
"use client";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
import type { LlmProvider } from "@/app/api/__generated__/models/llmProvider";
import { DeleteProviderModal } from "./DeleteProviderModal";
import { EditProviderModal } from "./EditProviderModal";
export function ProviderList({ providers }: { providers: LlmProvider[] }) {
if (!providers.length) {
return (
<div className="rounded-lg border border-dashed border-border p-6 text-center text-sm text-muted-foreground">
No providers configured yet.
</div>
);
}
return (
<div className="rounded-lg border">
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>Display Name</TableHead>
<TableHead>Default Credential</TableHead>
<TableHead>Capabilities</TableHead>
<TableHead>Models</TableHead>
<TableHead className="w-[100px]">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{providers.map((provider) => (
<TableRow key={provider.id}>
<TableCell className="font-medium">{provider.name}</TableCell>
<TableCell>{provider.display_name}</TableCell>
<TableCell>
{provider.default_credential_provider
? `${provider.default_credential_provider} (${provider.default_credential_id ?? "id?"})`
: "—"}
</TableCell>
<TableCell className="text-sm text-muted-foreground">
<div className="flex flex-wrap gap-2">
{provider.supports_tools && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Tools
</span>
)}
{provider.supports_json_output && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
JSON
</span>
)}
{provider.supports_reasoning && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Reasoning
</span>
)}
{provider.supports_parallel_tool && (
<span className="rounded bg-muted px-2 py-0.5 text-xs">
Parallel Tools
</span>
)}
</div>
</TableCell>
<TableCell className="text-sm">
<span
className={
(provider.models?.length ?? 0) > 0
? "text-foreground"
: "text-muted-foreground"
}
>
{provider.models?.length ?? 0}
</span>
</TableCell>
<TableCell>
<div className="flex gap-2">
<EditProviderModal provider={provider} />
<DeleteProviderModal provider={provider} />
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
}

View File

@@ -1,87 +0,0 @@
"use client";
import { useState } from "react";
import { useRouter } from "next/navigation";
import type { LlmModel } from "@/app/api/__generated__/models/llmModel";
import { Button } from "@/components/atoms/Button/Button";
import { setRecommendedModelAction } from "../actions";
import { Star } from "@phosphor-icons/react";
export function RecommendedModelSelector({ models }: { models: LlmModel[] }) {
const router = useRouter();
const enabledModels = models.filter((m) => m.is_enabled);
const currentRecommended = models.find((m) => m.is_recommended);
const [selectedModelId, setSelectedModelId] = useState<string>(
currentRecommended?.id || "",
);
const [isSaving, setIsSaving] = useState(false);
const [error, setError] = useState<string | null>(null);
const hasChanges = selectedModelId !== (currentRecommended?.id || "");
async function handleSave() {
if (!selectedModelId) return;
setIsSaving(true);
setError(null);
try {
const formData = new FormData();
formData.set("model_id", selectedModelId);
await setRecommendedModelAction(formData);
router.refresh();
} catch (err) {
setError(err instanceof Error ? err.message : "Failed to save");
} finally {
setIsSaving(false);
}
}
return (
<div className="rounded-lg border border-border bg-card p-4">
<div className="mb-3 flex items-center gap-2">
<Star size={20} weight="fill" className="text-amber-500" />
<h3 className="text-sm font-semibold">Recommended Model</h3>
</div>
<p className="mb-3 text-xs text-muted-foreground">
The recommended model is shown as the default suggestion in model
selection dropdowns throughout the platform.
</p>
<div className="flex items-center gap-3">
<select
value={selectedModelId}
onChange={(e) => setSelectedModelId(e.target.value)}
className="flex-1 rounded-md border border-input bg-background px-3 py-2 text-sm"
disabled={isSaving}
>
<option value="">-- Select a model --</option>
{enabledModels.map((model) => (
<option key={model.id} value={model.id}>
{model.display_name} ({model.slug})
</option>
))}
</select>
<Button
type="button"
variant="primary"
size="small"
onClick={handleSave}
disabled={!hasChanges || !selectedModelId || isSaving}
>
{isSaving ? "Saving..." : "Save"}
</Button>
</div>
{error && <p className="mt-2 text-xs text-destructive">{error}</p>}
{currentRecommended && !hasChanges && (
<p className="mt-2 text-xs text-muted-foreground">
Currently set to:{" "}
<span className="font-medium">{currentRecommended.display_name}</span>
</p>
)}
</div>
);
}

View File

@@ -1,46 +0,0 @@
/**
* Server-side data fetching for LLM Registry page.
*/
import {
fetchLlmCreators,
fetchLlmMigrations,
fetchLlmModels,
fetchLlmProviders,
} from "./actions";
export async function getLlmRegistryPageData() {
// Fetch providers and models (required)
const [providersResponse, modelsResponse] = await Promise.all([
fetchLlmProviders(),
fetchLlmModels(),
]);
// Fetch migrations separately with fallback (table might not exist yet)
let migrations: Awaited<ReturnType<typeof fetchLlmMigrations>>["migrations"] =
[];
try {
const migrationsResponse = await fetchLlmMigrations(false);
migrations = migrationsResponse.migrations;
} catch {
// Migrations table might not exist yet - that's ok, just show empty list
console.warn("Could not fetch migrations - table may not exist yet");
}
// Fetch creators separately with fallback (table might not exist yet)
let creators: Awaited<ReturnType<typeof fetchLlmCreators>>["creators"] = [];
try {
const creatorsResponse = await fetchLlmCreators();
creators = creatorsResponse.creators;
} catch {
// Creators table might not exist yet - that's ok, just show empty list
console.warn("Could not fetch creators - table may not exist yet");
}
return {
providers: providersResponse.providers,
models: modelsResponse.models,
migrations,
creators,
};
}

View File

@@ -1,14 +0,0 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { getLlmRegistryPageData } from "./getLlmRegistryPage";
import { LlmRegistryDashboard } from "./components/LlmRegistryDashboard";
async function LlmRegistryPage() {
const data = await getLlmRegistryPageData();
return <LlmRegistryDashboard {...data} />;
}
export default async function AdminLlmRegistryPage() {
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedLlmRegistryPage = await withAdminAccess(LlmRegistryPage);
return <ProtectedLlmRegistryPage />;
}

View File

@@ -7,9 +7,8 @@ import { BlockCategoryResponse } from "@/app/api/__generated__/models/blockCateg
import { BlockResponse } from "@/app/api/__generated__/models/blockResponse";
import * as Sentry from "@sentry/nextjs";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useState, useEffect } from "react";
import { useState } from "react";
import { useToast } from "@/components/molecules/Toast/use-toast";
import BackendApi from "@/lib/autogpt-server-api";
export const useAllBlockContent = () => {
const { toast } = useToast();
@@ -94,32 +93,6 @@ export const useAllBlockContent = () => {
const isErrorOnLoadingMore = (categoryName: string) =>
errorLoadingCategories.has(categoryName);
// Listen for LLM registry refresh notifications
useEffect(() => {
const api = new BackendApi();
const queryClient = getQueryClient();
const handleNotification = (notification: any) => {
if (
notification?.type === "LLM_REGISTRY_REFRESH" ||
notification?.event === "registry_updated"
) {
// Invalidate all block-related queries to force refresh
const categoriesQueryKey = getGetV2GetBuilderBlockCategoriesQueryKey();
queryClient.invalidateQueries({ queryKey: categoriesQueryKey });
}
};
const unsubscribe = api.onWebSocketMessage(
"notification",
handleNotification,
);
return () => {
unsubscribe();
};
}, []);
return {
data,
isLoading,

View File

@@ -610,11 +610,8 @@ const NodeOneOfDiscriminatorField: FC<{
return oneOfVariants
.map((variant) => {
const discProperty = variant.properties?.[discriminatorProperty];
const variantDiscValue =
discProperty && "const" in discProperty
? (discProperty.const as string)
: undefined; // NOTE: can discriminators only be strings?
const variantDiscValue = variant.properties?.[discriminatorProperty]
?.const as string; // NOTE: can discriminators only be strings?
return {
value: variantDiscValue,
@@ -1127,47 +1124,9 @@ const NodeStringInput: FC<{
displayName,
}) => {
value ||= schema.default || "";
// Check if we have options with labels (e.g., LLM model picker)
const hasOptions = schema.options && schema.options.length > 0;
const hasEnum = schema.enum && schema.enum.length > 0;
// Helper to get display label for a value
const getDisplayLabel = (val: string) => {
if (hasOptions) {
const option = schema.options!.find((opt) => opt.value === val);
return option?.label || beautifyString(val);
}
return beautifyString(val);
};
return (
<div className={className}>
{hasOptions ? (
// Render options with proper labels (used by LLM model picker)
<Select
defaultValue={value}
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder || displayName}>
{value ? getDisplayLabel(value) : undefined}
</SelectValue>
</SelectTrigger>
<SelectContent className="nodrag">
{schema.options!.map((option, index) => (
<SelectItem
key={index}
value={option.value}
title={option.description}
>
{option.label || beautifyString(option.value)}
</SelectItem>
))}
</SelectContent>
</Select>
) : hasEnum ? (
// Fallback to enum with beautified strings
{schema.enum && schema.enum.length > 0 ? (
<Select
defaultValue={value}
onValueChange={(newValue) => handleInputChange(selfKey, newValue)}
@@ -1176,8 +1135,8 @@ const NodeStringInput: FC<{
<SelectValue placeholder={schema.placeholder || displayName} />
</SelectTrigger>
<SelectContent className="nodrag">
{schema
.enum!.filter((option) => option)
{schema.enum
.filter((option) => option)
.map((option, index) => (
<SelectItem key={index} value={option}>
{beautifyString(option)}

View File

@@ -20,6 +20,7 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { GenericTool } from "../../tools/GenericTool/GenericTool";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// ---------------------------------------------------------------------------
@@ -255,6 +256,16 @@ export const ChatMessagesContainer = ({
/>
);
default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
if (part.type.startsWith("tool-")) {
return (
<GenericTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
}
return null;
}
})}

View File

@@ -0,0 +1,63 @@
"use client";
import { ToolUIPart } from "ai";
import { GearIcon } from "@phosphor-icons/react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
interface Props {
part: ToolUIPart;
}
function extractToolName(part: ToolUIPart): string {
// ToolUIPart.type is "tool-{name}", extract the name portion.
return part.type.replace(/^tool-/, "");
}
function formatToolName(name: string): string {
// "search_docs" → "Search docs", "Read" → "Read"
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
}
function getAnimationText(part: ToolUIPart): string {
const label = formatToolName(extractToolName(part));
switch (part.state) {
case "input-streaming":
case "input-available":
return `Running ${label}`;
case "output-available":
return `${label} completed`;
case "output-error":
return `${label} failed`;
default:
return `Running ${label}`;
}
}
export function GenericTool({ part }: Props) {
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<GearIcon
size={14}
weight="regular"
className={
isError
? "text-red-500"
: isStreaming
? "animate-spin text-neutral-500"
: "text-neutral-400"
}
/>
<MorphingTextAnimation
text={getAnimationText(part)}
className={isError ? "text-red-500" : undefined}
/>
</div>
</div>
);
}
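
A minimal usage sketch for the new GenericTool component, assuming the ToolUIPart shape exported by the ai package; the part values and the example component are illustrative only, and the import path mirrors the one used in ChatMessagesContainer above.

import type { ToolUIPart } from "ai";
import { GenericTool } from "../../tools/GenericTool/GenericTool";

// Hypothetical part as it might arrive for an SDK built-in "Read" tool;
// field values here are illustrative, not taken from a real stream.
const readToolPart = {
  type: "tool-Read",
  toolCallId: "call_1",
  state: "input-available",
  input: { file_path: "/tmp/example.txt" },
} as ToolUIPart;

// Shows a spinning gear with "Running Read" while the tool is in an input
// state; when the part's state becomes "output-available" the label morphs
// to "Read completed", and "output-error" renders "Read failed" in red.
export function ReadToolExample() {
  return <GenericTool part={readToolPart} />;
}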

File diff suppressed because it is too large

View File

@@ -1,123 +0,0 @@
import * as React from "react";
import { cn } from "@/lib/utils";
const Table = React.forwardRef<
HTMLTableElement,
React.HTMLAttributes<HTMLTableElement>
>(({ className, ...props }, ref) => (
<div className="relative w-full overflow-auto">
<table
ref={ref}
className={cn("w-full caption-bottom text-sm", className)}
{...props}
/>
</div>
));
Table.displayName = "Table";
const TableHeader = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<thead ref={ref} className={cn("[&_tr]:border-b", className)} {...props} />
));
TableHeader.displayName = "TableHeader";
const TableBody = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tbody
ref={ref}
className={cn("[&_tr:last-child]:border-0", className)}
{...props}
/>
));
TableBody.displayName = "TableBody";
const TableFooter = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tfoot
ref={ref}
className={cn(
"border-t bg-neutral-100/50 font-medium dark:bg-neutral-800/50 [&>tr]:last:border-b-0",
className,
)}
{...props}
/>
));
TableFooter.displayName = "TableFooter";
const TableRow = React.forwardRef<
HTMLTableRowElement,
React.HTMLAttributes<HTMLTableRowElement>
>(({ className, ...props }, ref) => (
<tr
ref={ref}
className={cn(
"border-b transition-colors data-[state=selected]:bg-neutral-100 hover:bg-neutral-100/50 dark:data-[state=selected]:bg-neutral-800 dark:hover:bg-neutral-800/50",
className,
)}
{...props}
/>
));
TableRow.displayName = "TableRow";
const TableHead = React.forwardRef<
HTMLTableCellElement,
React.ThHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<th
ref={ref}
className={cn(
"h-10 px-2 text-left align-middle font-medium text-neutral-500 dark:text-neutral-400 [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
className,
)}
{...props}
/>
));
TableHead.displayName = "TableHead";
const TableCell = React.forwardRef<
HTMLTableCellElement,
React.TdHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<td
ref={ref}
className={cn(
"p-2 align-middle [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
className,
)}
{...props}
/>
));
TableCell.displayName = "TableCell";
const TableCaption = React.forwardRef<
HTMLTableCaptionElement,
React.HTMLAttributes<HTMLTableCaptionElement>
>(({ className, ...props }, ref) => (
<caption
ref={ref}
className={cn(
"mt-4 text-sm text-neutral-500 dark:text-neutral-400",
className,
)}
{...props}
/>
));
TableCaption.displayName = "TableCaption";
export {
Table,
TableHeader,
TableBody,
TableFooter,
TableHead,
TableRow,
TableCell,
TableCaption,
};

View File

@@ -6,7 +6,7 @@ import {
TableHead,
TableHeader,
TableRow,
} from "@/components/atoms/Table/Table";
} from "@/components/__legacy__/ui/table";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";

View File

@@ -1,19 +1,8 @@
import { RJSFSchema } from "@rjsf/utils";
/**
* Options type for fields with label/value pairs (e.g., LLM model picker)
*/
type SchemaOption = {
label: string;
value: string;
group?: string;
description?: string;
};
/**
* Pre-processes the input schema to ensure all properties have a type defined.
* If a property doesn't have a type, it assigns a union of all supported JSON Schema types.
* Also converts custom 'options' array to RJSF's enum/enumNames format.
*/
export function preprocessInputSchema(schema: RJSFSchema): RJSFSchema {
@@ -31,20 +20,6 @@ export function preprocessInputSchema(schema: RJSFSchema): RJSFSchema {
if (property && typeof property === "object") {
const processedProperty = { ...property };
// Convert custom 'options' array to RJSF's enum/enumNames format
// This enables proper label display for dropdowns like the LLM model picker
if (
(processedProperty as any).options &&
Array.isArray((processedProperty as any).options) &&
(processedProperty as any).options.length > 0
) {
const options = (processedProperty as any).options as SchemaOption[];
processedProperty.enum = options.map((opt) => opt.value);
(processedProperty as any).enumNames = options.map(
(opt) => opt.label,
);
}
// Only add type if no type is defined AND no anyOf/oneOf/allOf is present
if (
!processedProperty.type &&

View File

@@ -77,45 +77,17 @@ export default function useAgentGraph(
// Load available blocks & flows (stable - only loads once)
useEffect(() => {
const loadBlocks = () => {
api
.getBlocks()
.then((blocks) => {
setAllBlocks(blocks);
})
.catch();
};
api
.getBlocks()
.then((blocks) => {
setAllBlocks(blocks);
})
.catch();
const loadFlows = () => {
api
.listGraphs()
.then((flows) => setAvailableFlows(flows))
.catch();
};
// Initial load
loadBlocks();
loadFlows();
// Listen for LLM registry refresh notifications to reload blocks
const deregisterRegistryRefresh = api.onWebSocketMessage(
"notification",
(notification) => {
if (
notification?.type === "LLM_REGISTRY_REFRESH" ||
notification?.event === "registry_updated"
) {
console.log(
"Received LLM registry refresh notification, reloading blocks...",
);
loadBlocks();
}
},
);
return () => {
deregisterRegistryRefresh();
};
api
.listGraphs()
.then((flows) => setAvailableFlows(flows))
.catch();
}, [api]);
// Subscribe to execution events

View File

@@ -186,7 +186,6 @@ export type BlockIOStringSubSchema = BlockIOSubSchemaMeta & {
default?: string;
format?: string;
maxLength?: number;
options?: { value: string; label: string; description?: string }[];
};
export type BlockIONumberSubSchema = BlockIOSubSchemaMeta & {

View File

@@ -285,20 +285,17 @@ export function fillObjectDefaultsFromSchema(
// Apply simple default values
obj[key] ??= propertySchema.default;
} else if (
"type" in propertySchema &&
propertySchema.type === "object" &&
"properties" in propertySchema
) {
// Recursively fill defaults for nested objects
obj[key] = fillObjectDefaultsFromSchema(obj[key] ?? {}, propertySchema);
} else if ("type" in propertySchema && propertySchema.type === "array") {
} else if (propertySchema.type === "array") {
obj[key] ??= [];
// If the array items are objects, fill their defaults as well
if (
Array.isArray(obj[key]) &&
propertySchema.items &&
"type" in propertySchema.items &&
propertySchema.items.type === "object" &&
propertySchema.items?.type === "object" &&
"properties" in propertySchema.items
) {
for (const item of obj[key]) {