Compare commits

..

9 Commits

Author SHA1 Message Date
Zamil Majdy
b071cd42f5 Merge remote-tracking branch 'origin/dev' into feat/agent-gen-smart-decision 2026-03-17 06:17:15 +07:00
Zamil Majdy
210a69d33e fix(copilot): update stale docstring for SDM fixer default
The docstring still referenced -1 (infinite) after changing the default
to 10 (bounded).
2026-03-17 00:19:39 +07:00
Zamil Majdy
2ddddc0257 fix(copilot): use bounded default for agent_mode_max_iterations
Change fixer default from -1 (infinite) to 10 (bounded) for safety.
Update guide to let LLM choose iteration count based on task complexity:
1 for single-step, 3-10 for multi-step, -1 for open-ended orchestration.
2026-03-17 00:17:56 +07:00
Zamil Majdy
59f05ed23c fix(copilot): guard SDM input_default type and reject invalid iterations
- Handle input_default=None by replacing with empty dict in fixer
- Reject agent_mode_max_iterations < -1 in validator (only -1 or
  positive values are valid)
- Add tests for both edge cases
2026-03-16 22:41:40 +07:00
Zamil Majdy
09fd30e14f fix(copilot): reject agent_mode_max_iterations=0 in SDM validator
Traditional mode (0) requires complex external conversation-history
loop wiring that the agent generator does not produce. Validate that
generated SDM nodes use agent mode (-1 or positive) only.
2026-03-16 22:25:43 +07:00
Zamil Majdy
ef6118c640 fix(copilot): treat explicit null SDM fields as missing defaults
The fixer now treats `None` values in SDM input_default as missing,
overwriting them with the correct defaults. This handles cases where
the LLM generates `null` for fields it doesn't know the value of.
2026-03-16 22:21:39 +07:00
Zamil Majdy
5d527eab85 fix(copilot): address review comments on SDM validator and tests
- Filter out AgentInput/OutputBlocks from SDM tool link check (interface
  blocks aren't real tools)
- Add test for interface-block-only links failing validation
- Add test for explicit None values being treated as missing in fixer
- Assert full graph validity in e2e pipeline test
2026-03-16 22:19:45 +07:00
Zamil Majdy
f0af149f16 fix(copilot): keep SmartDecisionMakerBlock excluded from CoPilot standalone
Address review feedback:
- Revert find_block.py exclusion removal — SDM requires graph context
  and would crash if run via run_block (missing execution_processor)
- The guide hardcodes the block ID so agent generation still works
- Add warning against agent_mode_max_iterations=0 in guide
- Hoist _DEFAULTS to module-level _SDM_DEFAULTS constant in fixer.py
2026-03-16 21:58:57 +07:00
Zamil Majdy
5ad71099ac feat(copilot): enable SmartDecisionMakerBlock in agent generator
Allow the agent generator to create orchestrator agents that use
SmartDecisionMakerBlock with agent mode to autonomously decide which
tools or sub-agents to call in a loop until the task is complete.

Changes:
- Remove SmartDecisionMakerBlock from COPILOT_EXCLUDED_BLOCK_IDS
- Add SMART_DECISION_MAKER_BLOCK_ID constant to helpers
- Add fixer to populate agent-mode defaults (max_iterations=-1, etc.)
- Add validator to ensure downstream tool blocks are connected
- Document SmartDecisionMakerBlock usage in agent_generation_guide.md
- Add 18 tests covering fixer, validator, and e2e pipeline
2026-03-16 21:51:24 +07:00
20 changed files with 848 additions and 1004 deletions

View File

@@ -1,162 +0,0 @@
"""Integration credential lookup with per-process TTL cache.
Provides token retrieval for connected integrations so that copilot tools
(e.g. bash_exec) can inject auth tokens into the execution environment without
hitting the database on every command.
Cache semantics (handled automatically by TTLCache):
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
for users who have credentials and are running many bash commands.
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
on every E2B command for users who haven't connected an account yet, while
still picking up a newly-connected account within one minute.
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
least-recently-used entry when the limit is reached.
Multi-worker note: both caches are in-process only. Each worker/replica
maintains its own independent cache, so a credential fetch may be duplicated
across processes. This is acceptable for the current goal (reduce DB hits per
session per-process), but if cache efficiency across replicas becomes important
a shared cache (e.g. Redis) should be used instead.
"""
import logging
from typing import cast
from cachetools import TTLCache
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
register_creds_changed_hook,
)
logger = logging.getLogger(__name__)
# Maps provider slug → env var names to inject when the provider is connected.
# Add new providers here when adding integration support.
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
# must be updated when adding a new provider.
PROVIDER_ENV_VARS: dict[str, list[str]] = {
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
}
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
_CACHE_MAX_SIZE = 10_000
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
# because all callers (get_provider_token, invalidate_user_provider_cache) run
# exclusively on the asyncio event loop. There are no await points between a
# cache read and its corresponding write within any function, so no concurrent
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
# this path, a threading.RLock should be wrapped around these caches.
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
)
# Separate cache for "no credentials" results with a shorter TTL.
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
# entries. This avoids a lazy import inside creds_manager and eliminates the
# circular-import risk.
register_creds_changed_hook(invalidate_user_provider_cache)
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
# on every cache-miss call to get_provider_token().
_manager = IntegrationCredentialsManager()
async def get_provider_token(user_id: str, provider: str) -> str | None:
"""Return the user's access token for *provider*, or ``None`` if not connected.
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
command for users who haven't connected yet, while still picking up a
newly-connected account within one minute.
"""
cache_key = (user_id, provider)
if cache_key in _null_cache:
return None
if cached := _token_cache.get(cache_key):
return cached
manager = _manager
try:
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
except Exception:
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
return None
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
# full git access, while a public-data-only token lacks push/pull permission.
# lock=False — background injection; not worth a distributed lock acquisition.
oauth2_creds = sorted(
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
fresh = await manager.refresh_if_needed(
user_id, cast(OAuth2Credentials, creds), lock=False
)
token = fresh.access_token.get_secret_value()
except Exception:
logger.warning(
"Failed to refresh %s OAuth token for user %s; "
"falling back to potentially stale token",
provider,
user_id,
)
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
_token_cache[cache_key] = token
return token
# Pass 2: fall back to API key (no expiry, no refresh needed).
for creds in creds_list:
if creds.type == "api_key":
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
"""Return env vars for all providers the user has connected.
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
Only providers with a stored credential contribute entries.
"""
env: dict[str, str] = {}
for provider, var_names in PROVIDER_ENV_VARS.items():
token = await get_provider_token(user_id, provider)
if token:
for var in var_names:
env[var] = token
return env

View File

@@ -1,193 +0,0 @@
"""Tests for integration_creds — TTL cache and token lookup paths."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_null_cache,
_token_cache,
get_integration_env_vars,
get_provider_token,
invalidate_user_provider_cache,
)
from backend.data.model import APIKeyCredentials, OAuth2Credentials
_USER = "user-integration-creds-test"
_PROVIDER = "github"
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
return APIKeyCredentials(
id="creds-api-key",
provider=_PROVIDER,
api_key=SecretStr(key),
title="Test API Key",
expires_at=None,
)
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
return OAuth2Credentials(
id="creds-oauth2",
provider=_PROVIDER,
title="Test OAuth",
access_token=SecretStr(token),
refresh_token=SecretStr("test-refresh"),
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
)
@pytest.fixture(autouse=True)
def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
class TestInvalidateUserProviderCache:
def test_removes_token_entry(self):
key = (_USER, _PROVIDER)
_token_cache[key] = "tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _token_cache
def test_removes_null_entry(self):
key = (_USER, _PROVIDER)
_null_cache[key] = True
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _null_cache
def test_noop_when_key_not_cached(self):
# Should not raise even when there is no cache entry.
invalidate_user_provider_cache("no-such-user", _PROVIDER)
def test_only_removes_targeted_key(self):
other_key = ("other-user", _PROVIDER)
_token_cache[other_key] = "other-tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_cached_token_without_db_hit(self):
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "cached-tok"
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_none_for_null_cached_provider(self):
_null_cache[(_USER, _PROVIDER)] = True
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_api_key_creds_returned_and_cached(self):
api_creds = _make_api_key_creds("my-api-key")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "my-api-key"
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_preferred_over_api_key(self):
oauth_creds = _make_oauth2_creds("oauth-tok")
api_creds = _make_api_key_creds("api-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
return_value=[api_creds, oauth_creds]
)
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "stale-oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
assert _null_cache.get((_USER, _PROVIDER)) is True
@pytest.mark.asyncio(loop_scope="session")
async def test_db_exception_returns_none_without_caching(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
side_effect=RuntimeError("db down")
)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
# DB errors are not cached — next call will retry
assert (_USER, _PROVIDER) not in _token_cache
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
"""Verify the TTL constants are set correctly for each cache."""
assert _null_cache.ttl == _NULL_CACHE_TTL
assert _token_cache.ttl == _TOKEN_CACHE_TTL
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):
_token_cache[(_USER, "github")] = "gh-tok"
result = await get_integration_env_vars(_USER)
for var in PROVIDER_ENV_VARS["github"]:
assert result[var] == "gh-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_dict_when_no_credentials(self):
_null_cache[(_USER, "github")] = True
result = await get_integration_env_vars(_USER)
assert result == {}

View File

@@ -95,25 +95,6 @@ Example — committing an image file to GitHub:
All tasks must run in the foreground.
"""
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.
- For operations that need broader access (e.g. private org repos, GitHub
Actions), pass the required scopes: e.g.
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
@@ -124,7 +105,6 @@ def _build_storage_supplement(
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
extra_notes: str = "",
) -> str:
"""Build storage/filesystem supplement for a specific environment.
@@ -139,7 +119,6 @@ def _build_storage_supplement(
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
extra_notes: Environment-specific notes appended after shared notes
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
@@ -173,16 +152,12 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}{extra_notes}"""
{_SHARED_TOOL_NOTES}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns).
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
like gh will not work — no integration env-var notes are included.
"""
"""Local ephemeral storage (files lost between turns)."""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
@@ -200,11 +175,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session).
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
injected per command in bash_exec — include the CLI guidance notes.
"""
"""Cloud persistent sandbox (files survive across turns in session)."""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
@@ -219,7 +190,6 @@ def _get_cloud_sandbox_supplement() -> str:
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
extra_notes=_E2B_TOOL_NOTES,
)

View File

@@ -143,6 +143,48 @@ To use an MCP (Model Context Protocol) tool as a node in the agent:
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Using SmartDecisionMakerBlock (AI Orchestrator with Agent Mode)
To create an agent where AI autonomously decides which tools or sub-agents to
call in a loop until the task is complete:
1. Create a `SmartDecisionMakerBlock` node
(ID: `3b191d9f-356f-482d-8238-ba04b6d18381`)
2. Set `input_default`:
- `agent_mode_max_iterations`: Choose based on task complexity:
- `1` for single-step tool calls (AI picks one tool, calls it, done)
- `3``10` for multi-step tasks (AI calls tools iteratively)
- `-1` for open-ended orchestration (AI loops until it decides it's done)
Do NOT use `0` (traditional mode) — it requires complex external
conversation-history loop wiring that the agent generator does not
produce.
- `conversation_compaction`: `true` (recommended to avoid context overflow)
- Optional: `sys_prompt` for extra LLM context about how to orchestrate
3. Wire the `prompt` input from an `AgentInputBlock` (the user's task)
4. Create downstream tool blocks — regular blocks **or** `AgentExecutorBlock`
nodes that call sub-agents
5. Link each tool to the SmartDecisionMaker: set `source_name: "tools"` on
the SmartDecisionMaker side and `sink_name: <input_field>` on each tool
block's input. Create one link per input field the tool needs.
6. Wire the `finished` output to an `AgentOutputBlock` for the final result
7. Credentials (LLM API key) are configured by the user in the platform UI
after saving — do NOT require them upfront
**Example — Orchestrator calling two sub-agents:**
- Node 1: `AgentInputBlock` (input_default: `{"name": "task"}`)
- Node 2: `SmartDecisionMakerBlock` (input_default:
`{"agent_mode_max_iterations": 10, "conversation_compaction": true}`)
- Node 3: `AgentExecutorBlock` (sub-agent A — set `graph_id`, `graph_version`,
`input_schema`, `output_schema` from library agent)
- Node 4: `AgentExecutorBlock` (sub-agent B — same pattern)
- Node 5: `AgentOutputBlock` (input_default: `{"name": "result"}`)
- Links:
- Input→SDM: `source_name: "result"`, `sink_name: "prompt"`
- SDM→Agent A (per input field): `source_name: "tools"`,
`sink_name: "<agent_a_input_field>"`
- SDM→Agent B (per input field): `source_name: "tools"`,
`sink_name: "<agent_b_input_field>"`
- SDM→Output: `source_name: "finished"`, `sink_name: "value"`
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:

View File

@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
)
return None
try:
sandbox = await get_or_create_sandbox(
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
@@ -783,9 +783,7 @@ async def stream_chat_completion_sdk(
e2b_err,
exc_info=True,
)
return None
return sandbox
return None
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""

View File

@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
@@ -85,7 +84,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
"connect_integration": ConnectIntegrationTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),

View File

@@ -7,6 +7,7 @@ from typing import Any
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
SMART_DECISION_MAKER_BLOCK_ID,
AgentDict,
are_types_compatible,
generate_uuid,
@@ -30,6 +31,14 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
# Defaults applied to SmartDecisionMakerBlock nodes by the fixer.
_SDM_DEFAULTS: dict[str, object] = {
"agent_mode_max_iterations": 10,
"conversation_compaction": True,
"retry": 3,
"multiple_tool_calls": False,
}
class AgentFixer:
"""
@@ -1630,6 +1639,43 @@ class AgentFixer:
return agent
def fix_smart_decision_maker_blocks(self, agent: AgentDict) -> AgentDict:
"""Fix SmartDecisionMakerBlock nodes to ensure agent-mode defaults.
Ensures:
1. ``agent_mode_max_iterations`` defaults to ``10`` (bounded agent mode)
2. ``conversation_compaction`` defaults to ``True``
3. ``retry`` defaults to ``3``
4. ``multiple_tool_calls`` defaults to ``False``
Args:
agent: The agent dictionary to fix
Returns:
The fixed agent dictionary
"""
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
continue
node_id = node.get("id", "unknown")
input_default = node.get("input_default")
if not isinstance(input_default, dict):
input_default = {}
node["input_default"] = input_default
for field, default_value in _SDM_DEFAULTS.items():
if field not in input_default or input_default[field] is None:
input_default[field] = default_value
self.add_fix_log(
f"SmartDecisionMakerBlock {node_id}: "
f"Set {field}={default_value!r}"
)
return agent
def fix_dynamic_block_sink_names(self, agent: AgentDict) -> AgentDict:
"""Fix links that use _#_ notation for dynamic block sink names.
@@ -1717,6 +1763,9 @@ class AgentFixer:
# Apply fixes for MCPToolBlock nodes
agent = self.fix_mcp_tool_blocks(agent)
# Apply fixes for SmartDecisionMakerBlock nodes (agent-mode defaults)
agent = self.fix_smart_decision_maker_blocks(agent)
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
if library_agents:
agent = self.fix_agent_executor_blocks(agent, library_agents)

View File

@@ -12,6 +12,7 @@ __all__ = [
"AGENT_OUTPUT_BLOCK_ID",
"AgentDict",
"MCP_TOOL_BLOCK_ID",
"SMART_DECISION_MAKER_BLOCK_ID",
"UUID_REGEX",
"are_types_compatible",
"generate_uuid",
@@ -33,6 +34,7 @@ UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
SMART_DECISION_MAKER_BLOCK_ID = "3b191d9f-356f-482d-8238-ba04b6d18381"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"

View File

@@ -10,6 +10,7 @@ from .helpers import (
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
SMART_DECISION_MAKER_BLOCK_ID,
AgentDict,
are_types_compatible,
get_defined_property_type,
@@ -809,6 +810,73 @@ class AgentValidator:
return valid
def validate_smart_decision_maker_blocks(self, agent: AgentDict) -> bool:
"""Validate that SmartDecisionMakerBlock nodes have downstream tools.
Checks that each SmartDecisionMakerBlock node has at least one link
with ``source_name == "tools"`` connecting to a downstream block.
Without tools, the block has nothing to call and will error at runtime.
Returns True if all SmartDecisionMakerBlock nodes are valid.
"""
valid = True
nodes = agent.get("nodes", [])
links = agent.get("links", [])
node_lookup = {node.get("id", ""): node for node in nodes}
non_tool_block_ids = {AGENT_INPUT_BLOCK_ID, AGENT_OUTPUT_BLOCK_ID}
for node in nodes:
if node.get("block_id") != SMART_DECISION_MAKER_BLOCK_ID:
continue
node_id = node.get("id", "unknown")
customized_name = (node.get("metadata") or {}).get(
"customized_name", node_id
)
# Warn if agent_mode_max_iterations is 0 (traditional mode) —
# requires complex external conversation-history loop wiring
# that the agent generator does not produce.
input_default = node.get("input_default", {})
max_iter = input_default.get("agent_mode_max_iterations")
if isinstance(max_iter, int) and max_iter < -1:
self.add_error(
f"SmartDecisionMakerBlock node '{customized_name}' "
f"({node_id}) has invalid "
f"agent_mode_max_iterations={max_iter}. "
f"Use -1 for infinite or a positive number for "
f"bounded iterations."
)
valid = False
elif max_iter == 0:
self.add_error(
f"SmartDecisionMakerBlock node '{customized_name}' "
f"({node_id}) has agent_mode_max_iterations=0 "
f"(traditional mode). The agent generator only supports "
f"agent mode (set to -1 for infinite or a positive "
f"number for bounded iterations)."
)
valid = False
has_tools = any(
link.get("source_id") == node_id
and link.get("source_name") == "tools"
and node_lookup.get(link.get("sink_id", ""), {}).get("block_id")
not in non_tool_block_ids
for link in links
)
if not has_tools:
self.add_error(
f"SmartDecisionMakerBlock node '{customized_name}' "
f"({node_id}) has no downstream tool blocks connected. "
f"Connect at least one block to its 'tools' output so "
f"the AI has tools to call."
)
valid = False
return valid
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
"""Validate that MCPToolBlock nodes have required fields.
@@ -913,6 +981,10 @@ class AgentValidator:
"MCP tool blocks",
self.validate_mcp_tool_blocks(agent),
),
(
"SmartDecisionMaker blocks",
self.validate_smart_decision_maker_blocks(agent),
),
]
# Add AgentExecutorBlock detailed validation if library_agents

View File

@@ -22,7 +22,6 @@ from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.integration_creds import get_integration_env_vars
from backend.copilot.model import ChatSession
from .base import BaseTool
@@ -97,9 +96,7 @@ class BashExecTool(BaseTool):
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(
sandbox, command, timeout, session_id, user_id
)
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
@@ -136,27 +133,14 @@ class BashExecTool(BaseTool):
command: str,
timeout: int,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run().
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
for any user with connected accounts. E2B has full internet access, so
CLI tools like ``gh`` work without manual authentication.
"""
envs: dict[str, str] = {
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
}
if user_id is not None:
integration_env = await get_integration_env_vars(user_id)
envs.update(integration_env)
"""Execute *command* on the E2B sandbox via commands.run()."""
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs=envs,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",

View File

@@ -1,78 +0,0 @@
"""Tests for BashExecTool — E2B path with token injection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from ._test_data import make_session
from .bash_exec import BashExecTool
from .models import BashExecResponse
_USER = "user-bash-exec-test"
def _make_tool() -> BashExecTool:
return BashExecTool()
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
result = MagicMock()
result.exit_code = exit_code
result.stdout = stdout
result.stderr = stderr
sandbox = MagicMock()
sandbox.commands.run = AsyncMock(return_value=result)
return sandbox
class TestBashExecE2BTokenInjection:
@pytest.mark.asyncio(loop_scope="session")
async def test_token_injected_when_user_id_set(self):
"""When user_id is provided, integration env vars are merged into sandbox envs."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=_USER,
)
mock_get_env.assert_awaited_once_with(_USER)
call_kwargs = sandbox.commands.run.call_args[1]
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
assert isinstance(result, BashExecResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_no_token_injection_when_user_id_is_none(self):
"""When user_id is None, get_integration_env_vars must NOT be called."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=None,
)
mock_get_env.assert_not_called()
call_kwargs = sandbox.commands.run.call_args[1]
assert "GH_TOKEN" not in call_kwargs["envs"]
assert isinstance(result, BashExecResponse)

View File

@@ -1,215 +0,0 @@
"""Tool for prompting the user to connect a required integration.
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
"authentication required"), it calls this tool to surface the credentials
setup card in the chat — the same UI that appears when a GitHub block runs
without configured credentials.
"""
import functools
from typing import Any, TypedDict
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import (
ErrorResponse,
ResponseType,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .base import BaseTool
class _ProviderInfo(TypedDict):
name: str
types: list[str]
# Default OAuth scopes requested when the agent doesn't specify any.
scopes: list[str]
class _CredentialEntry(TypedDict):
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
id: str
title: str
provider: str
provider_name: str
type: str
types: list[str]
scopes: list[str]
@functools.lru_cache(maxsize=1)
def _is_github_oauth_configured() -> bool:
"""Return True if GitHub OAuth env vars are set.
Evaluated lazily (not at import time) to avoid triggering Secrets() during
module import, which can fail in environments where secrets are not loaded.
"""
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
return GITHUB_OAUTH_IS_CONFIGURED
# Registry of known providers: name + supported credential types for the UI.
# When adding a new provider, also add its env var names to
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
def _get_provider_info() -> dict[str, _ProviderInfo]:
"""Build the provider registry, evaluating OAuth config lazily."""
return {
"github": {
"name": "GitHub",
"types": (
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
),
# Default: repo scope covers clone/push/pull for public and private repos.
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
"scopes": ["repo"],
},
}
class ConnectIntegrationTool(BaseTool):
"""Surface the credentials setup UI when an integration is not connected."""
@property
def name(self) -> str:
return "connect_integration"
@property
def description(self) -> str:
return (
"Prompt the user to connect a required integration (e.g. GitHub). "
"Call this when an external CLI or API call fails because the user "
"has not connected the relevant account. "
"The tool surfaces a credentials setup card in the chat so the user "
"can authenticate without leaving the page. "
"After the user connects the account, retry the operation. "
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
"automatically injected per-command in bash_exec — no manual export needed. "
"In local bubblewrap mode network is isolated so GitHub CLI commands "
"will still fail after connecting; inform the user of this limitation."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": (
"Integration provider slug, e.g. 'github'. "
"Must be one of the supported providers."
),
"enum": list(_get_provider_info().keys()),
},
"reason": {
"type": "string",
"description": (
"Brief explanation of why the integration is needed, "
"shown to the user in the setup card."
),
"maxLength": 500,
},
"scopes": {
"type": "array",
"items": {"type": "string"},
"description": (
"OAuth scopes to request. Omit to use the provider default. "
"Add extra scopes when you need more access — e.g. for GitHub: "
"'repo' (clone/push/pull), 'read:org' (org membership), "
"'workflow' (GitHub Actions). "
"Requesting only the scopes you actually need is best practice."
),
},
},
"required": ["provider"],
}
@property
def requires_auth(self) -> bool:
# Require auth so only authenticated users can trigger the setup card.
# The card itself is user-agnostic (no per-user data needed), so
# user_id is intentionally unused in _execute.
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
provider_info = _get_provider_info()
info = provider_info.get(provider)
if not info:
supported = ", ".join(f"'{p}'" for p in provider_info)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
)
provider_name: str = info["name"]
supported_types: list[str] = info["types"]
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = info["scopes"]
seen: set[str] = set()
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
f"To continue, please connect your {provider_name} account.",
]
if reason:
message_parts.append(reason)
credential_entry: _CredentialEntry = {
"id": field_key,
"title": f"{provider_name} Credentials",
"provider": provider,
"provider_name": provider_name,
"type": supported_types[0],
"types": supported_types,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
return SetupRequirementsResponse(
type=ResponseType.SETUP_REQUIREMENTS,
message=" ".join(message_parts),
session_id=session_id,
setup_info=SetupInfo(
agent_id=f"connect_{provider}",
agent_name=provider_name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials,
ready_to_run=False,
),
requirements={
"credentials": [missing_credentials[field_key]],
"inputs": [],
"execution_modes": [],
},
),
)

View File

@@ -1,135 +0,0 @@
"""Tests for ConnectIntegrationTool."""
import pytest
from ._test_data import make_session
from .connect_integration import ConnectIntegrationTool
from .models import ErrorResponse, SetupRequirementsResponse
_TEST_USER_ID = "test-user-connect-integration"
class TestConnectIntegrationTool:
def _make_tool(self) -> ConnectIntegrationTool:
return ConnectIntegrationTool()
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
assert "nonexistent" in result.message
assert "github" in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_provider_returns_setup_response(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_name == "GitHub"
assert result.setup_info.agent_id == "connect_github"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_has_missing_credentials_in_readiness(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
readiness = result.setup_info.user_readiness
assert readiness.has_all_credentials is False
assert readiness.ready_to_run is False
assert "github_credentials" in readiness.missing_credentials
@pytest.mark.asyncio(loop_scope="session")
async def test_github_requirements_include_credential_entry(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
creds = result.setup_info.requirements["credentials"]
assert len(creds) == 1
assert creds[0]["provider"] == "github"
assert creds[0]["id"] == "github_credentials"
@pytest.mark.asyncio(loop_scope="session")
async def test_reason_appears_in_message(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
reason = "Needed to create a pull request."
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
)
assert isinstance(result, SetupRequirementsResponse)
assert reason in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_session_id_propagated(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.session_id == session.session_id
@pytest.mark.asyncio(loop_scope="session")
async def test_provider_case_insensitive(self):
"""Provider slug is normalised to lowercase before lookup."""
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="GitHub"
)
assert isinstance(result, SetupRequirementsResponse)
def test_tool_name(self):
assert ConnectIntegrationTool().name == "connect_integration"
def test_requires_auth(self):
assert ConnectIntegrationTool().requires_auth is True
@pytest.mark.asyncio(loop_scope="session")
async def test_unauthenticated_user_gets_need_login_response(self):
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
This verifies that the requires_auth guard in BaseTool.execute() fires
before _execute() is called, so unauthenticated callers cannot probe
which integrations are configured.
"""
import json
tool = self._make_tool()
# Session still needs a user_id string; the None is passed to execute()
# to simulate an unauthenticated call.
session = make_session(user_id=_TEST_USER_ID)
result = await tool.execute(
user_id=None,
session=session,
tool_call_id="test-call-id",
provider="github",
)
raw = result.output
output = json.loads(raw) if isinstance(raw, str) else raw
assert output.get("type") == "need_login"
assert result.success is False

View File

@@ -37,7 +37,8 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology;
# usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
"3b191d9f-356f-482d-8238-ba04b6d18381",
}

View File

@@ -25,35 +25,6 @@ logger = logging.getLogger(__name__)
settings = Settings()
_on_creds_changed: Callable[[str, str], None] | None = None
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
"""Register a callback invoked after any credential is created/updated/deleted.
The callback receives ``(user_id, provider)`` and should be idempotent.
Only one hook can be registered at a time; calling this again replaces the
previous hook. Intended to be called once at application startup by the
copilot module to bust its token cache without creating an import cycle.
"""
global _on_creds_changed
_on_creds_changed = hook
def _bust_copilot_cache(user_id: str, provider: str) -> None:
"""Invoke the registered hook (if any) to bust downstream token caches."""
if _on_creds_changed is not None:
try:
_on_creds_changed(user_id, provider)
except Exception:
logger.warning(
"Credential-change hook failed for user=%s provider=%s",
user_id,
provider,
exc_info=True,
)
class IntegrationCredentialsManager:
"""
Handles the lifecycle of integration credentials.
@@ -98,11 +69,7 @@ class IntegrationCredentialsManager:
return self._locks
async def create(self, user_id: str, credentials: Credentials) -> None:
result = await self.store.add_creds(user_id, credentials)
# Bust the copilot token cache so that the next bash_exec picks up the
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
_bust_copilot_cache(user_id, credentials.provider)
return result
return await self.store.add_creds(user_id, credentials)
async def exists(self, user_id: str, credentials_id: str) -> bool:
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
@@ -189,8 +156,6 @@ class IntegrationCredentialsManager:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Bust copilot cache so the refreshed token is picked up immediately.
_bust_copilot_cache(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
@@ -203,17 +168,10 @@ class IntegrationCredentialsManager:
async def update(self, user_id: str, updated: Credentials) -> None:
async with self._locked(user_id, updated.id):
await self.store.update_creds(user_id, updated)
# Bust the copilot token cache so the updated credential is picked up immediately.
_bust_copilot_cache(user_id, updated.provider)
async def delete(self, user_id: str, credentials_id: str) -> None:
async with self._locked(user_id, credentials_id):
# Read inside the lock to avoid TOCTOU — another coroutine could
# delete the same credential between the read and the delete.
creds = await self.store.get_creds_by_id(user_id, credentials_id)
await self.store.delete_creds_by_id(user_id, credentials_id)
if creds:
_bust_copilot_cache(user_id, creds.provider)
# -- Locking utilities -- #

View File

@@ -0,0 +1,668 @@
"""
Tests for SmartDecisionMakerBlock support in agent generator.
Covers:
- AgentFixer.fix_smart_decision_maker_blocks()
- AgentValidator.validate_smart_decision_maker_blocks()
- End-to-end fix → validate → pipeline for SmartDecisionMaker agents
"""
import uuid
from backend.copilot.tools.agent_generator.fixer import AgentFixer
from backend.copilot.tools.agent_generator.helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
SMART_DECISION_MAKER_BLOCK_ID,
)
from backend.copilot.tools.agent_generator.validator import AgentValidator
def _uid() -> str:
return str(uuid.uuid4())
def _make_sdm_node(
node_id: str | None = None,
input_default: dict | None = None,
metadata: dict | None = None,
) -> dict:
"""Create a SmartDecisionMakerBlock node dict."""
return {
"id": node_id or _uid(),
"block_id": SMART_DECISION_MAKER_BLOCK_ID,
"input_default": input_default or {},
"metadata": metadata or {"position": {"x": 0, "y": 0}},
}
def _make_agent_executor_node(
node_id: str | None = None,
graph_id: str | None = None,
) -> dict:
"""Create an AgentExecutorBlock node dict."""
return {
"id": node_id or _uid(),
"block_id": AGENT_EXECUTOR_BLOCK_ID,
"input_default": {
"graph_id": graph_id or _uid(),
"graph_version": 1,
"input_schema": {"properties": {"query": {"type": "string"}}},
"output_schema": {"properties": {"result": {"type": "string"}}},
"user_id": "",
"inputs": {},
},
"metadata": {"position": {"x": 800, "y": 0}},
}
def _make_input_node(node_id: str | None = None, name: str = "task") -> dict:
return {
"id": node_id or _uid(),
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {"name": name, "title": name.title()},
"metadata": {"position": {"x": -800, "y": 0}},
}
def _make_output_node(node_id: str | None = None, name: str = "result") -> dict:
return {
"id": node_id or _uid(),
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {"name": name, "title": name.title()},
"metadata": {"position": {"x": 1600, "y": 0}},
}
def _link(
source_id: str,
source_name: str,
sink_id: str,
sink_name: str,
is_static: bool = False,
) -> dict:
return {
"id": _uid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
"is_static": is_static,
}
def _make_orchestrator_agent() -> dict:
"""Build a complete orchestrator agent with SDM + 2 sub-agent tools."""
input_node = _make_input_node()
sdm_node = _make_sdm_node()
agent_a = _make_agent_executor_node()
agent_b = _make_agent_executor_node()
output_node = _make_output_node()
return {
"id": _uid(),
"version": 1,
"is_active": True,
"name": "Orchestrator Agent",
"description": "Uses AI to orchestrate sub-agents",
"nodes": [input_node, sdm_node, agent_a, agent_b, output_node],
"links": [
# Input → SDM prompt
_link(input_node["id"], "result", sdm_node["id"], "prompt"),
# SDM tools → Agent A
_link(sdm_node["id"], "tools", agent_a["id"], "query"),
# SDM tools → Agent B
_link(sdm_node["id"], "tools", agent_b["id"], "query"),
# SDM finished → Output
_link(sdm_node["id"], "finished", output_node["id"], "value"),
],
}
# ---------------------------------------------------------------------------
# Fixer tests
# ---------------------------------------------------------------------------
class TestFixSmartDecisionMakerBlocks:
"""Tests for AgentFixer.fix_smart_decision_maker_blocks()."""
def test_fills_defaults_when_missing(self):
"""All agent-mode defaults are populated for a bare SDM node."""
fixer = AgentFixer()
agent = {"nodes": [_make_sdm_node()], "links": []}
result = fixer.fix_smart_decision_maker_blocks(agent)
defaults = result["nodes"][0]["input_default"]
assert defaults["agent_mode_max_iterations"] == 10
assert defaults["conversation_compaction"] is True
assert defaults["retry"] == 3
assert defaults["multiple_tool_calls"] is False
assert len(fixer.fixes_applied) == 4
def test_preserves_existing_values(self):
"""Existing user-set values are never overwritten."""
fixer = AgentFixer()
agent = {
"nodes": [
_make_sdm_node(
input_default={
"agent_mode_max_iterations": 5,
"conversation_compaction": False,
"retry": 1,
"multiple_tool_calls": True,
}
)
],
"links": [],
}
result = fixer.fix_smart_decision_maker_blocks(agent)
defaults = result["nodes"][0]["input_default"]
assert defaults["agent_mode_max_iterations"] == 5
assert defaults["conversation_compaction"] is False
assert defaults["retry"] == 1
assert defaults["multiple_tool_calls"] is True
assert len(fixer.fixes_applied) == 0
def test_partial_defaults(self):
"""Only missing fields are filled; existing ones are kept."""
fixer = AgentFixer()
agent = {
"nodes": [
_make_sdm_node(
input_default={
"agent_mode_max_iterations": 10,
}
)
],
"links": [],
}
result = fixer.fix_smart_decision_maker_blocks(agent)
defaults = result["nodes"][0]["input_default"]
assert defaults["agent_mode_max_iterations"] == 10 # kept
assert defaults["conversation_compaction"] is True # filled
assert defaults["retry"] == 3 # filled
assert defaults["multiple_tool_calls"] is False # filled
assert len(fixer.fixes_applied) == 3
def test_skips_non_sdm_nodes(self):
"""Non-SmartDecisionMaker nodes are untouched."""
fixer = AgentFixer()
other_node = {
"id": _uid(),
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {"name": "test"},
"metadata": {},
}
agent = {"nodes": [other_node], "links": []}
result = fixer.fix_smart_decision_maker_blocks(agent)
assert "agent_mode_max_iterations" not in result["nodes"][0]["input_default"]
assert len(fixer.fixes_applied) == 0
def test_handles_missing_input_default(self):
"""Node with no input_default key gets one created."""
fixer = AgentFixer()
node = {
"id": _uid(),
"block_id": SMART_DECISION_MAKER_BLOCK_ID,
"metadata": {},
}
agent = {"nodes": [node], "links": []}
result = fixer.fix_smart_decision_maker_blocks(agent)
assert "input_default" in result["nodes"][0]
assert result["nodes"][0]["input_default"]["agent_mode_max_iterations"] == 10
def test_handles_none_input_default(self):
"""Node with input_default set to None gets a dict created."""
fixer = AgentFixer()
node = {
"id": _uid(),
"block_id": SMART_DECISION_MAKER_BLOCK_ID,
"input_default": None,
"metadata": {},
}
agent = {"nodes": [node], "links": []}
result = fixer.fix_smart_decision_maker_blocks(agent)
assert isinstance(result["nodes"][0]["input_default"], dict)
assert result["nodes"][0]["input_default"]["agent_mode_max_iterations"] == 10
def test_treats_none_values_as_missing(self):
"""Explicit None values are overwritten with defaults."""
fixer = AgentFixer()
agent = {
"nodes": [
_make_sdm_node(
input_default={
"agent_mode_max_iterations": None,
"conversation_compaction": None,
"retry": 3,
"multiple_tool_calls": False,
}
)
],
"links": [],
}
result = fixer.fix_smart_decision_maker_blocks(agent)
defaults = result["nodes"][0]["input_default"]
assert defaults["agent_mode_max_iterations"] == 10 # None → default
assert defaults["conversation_compaction"] is True # None → default
assert defaults["retry"] == 3 # kept
assert defaults["multiple_tool_calls"] is False # kept
assert len(fixer.fixes_applied) == 2
def test_multiple_sdm_nodes(self):
"""Multiple SDM nodes are all fixed independently."""
fixer = AgentFixer()
agent = {
"nodes": [
_make_sdm_node(input_default={"agent_mode_max_iterations": 3}),
_make_sdm_node(input_default={}),
],
"links": [],
}
result = fixer.fix_smart_decision_maker_blocks(agent)
# First node: 3 defaults filled (agent_mode was already set)
assert result["nodes"][0]["input_default"]["agent_mode_max_iterations"] == 3
# Second node: all 4 defaults filled
assert result["nodes"][1]["input_default"]["agent_mode_max_iterations"] == 10
assert len(fixer.fixes_applied) == 7 # 3 + 4
def test_registered_in_apply_all_fixes(self):
"""fix_smart_decision_maker_blocks runs as part of apply_all_fixes."""
fixer = AgentFixer()
agent = {
"nodes": [_make_sdm_node()],
"links": [],
}
result = fixer.apply_all_fixes(agent)
defaults = result["nodes"][0]["input_default"]
assert defaults["agent_mode_max_iterations"] == 10
assert any("SmartDecisionMakerBlock" in fix for fix in fixer.fixes_applied)
# ---------------------------------------------------------------------------
# Validator tests
# ---------------------------------------------------------------------------
class TestValidateSmartDecisionMakerBlocks:
"""Tests for AgentValidator.validate_smart_decision_maker_blocks()."""
def test_valid_sdm_with_tools(self):
"""SDM with downstream tool links passes validation."""
validator = AgentValidator()
agent = _make_orchestrator_agent()
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is True
assert len(validator.errors) == 0
def test_sdm_without_tools_fails(self):
"""SDM with no 'tools' links fails validation."""
validator = AgentValidator()
sdm = _make_sdm_node()
agent = {
"nodes": [sdm],
"links": [], # no tool links
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert len(validator.errors) == 1
assert "no downstream tool blocks" in validator.errors[0]
def test_sdm_with_non_tools_links_fails(self):
"""Links that don't use source_name='tools' don't count."""
validator = AgentValidator()
sdm = _make_sdm_node()
other = _make_agent_executor_node()
agent = {
"nodes": [sdm, other],
"links": [
# Link from 'finished' output, not 'tools'
_link(sdm["id"], "finished", other["id"], "query"),
],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert len(validator.errors) == 1
def test_no_sdm_nodes_passes(self):
"""Agent without SmartDecisionMaker nodes passes trivially."""
validator = AgentValidator()
agent = {
"nodes": [_make_input_node(), _make_output_node()],
"links": [],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is True
assert len(validator.errors) == 0
def test_error_includes_customized_name(self):
"""Error message includes the node's customized_name if set."""
validator = AgentValidator()
sdm = _make_sdm_node(
metadata={
"position": {"x": 0, "y": 0},
"customized_name": "My Orchestrator",
}
)
agent = {"nodes": [sdm], "links": []}
validator.validate_smart_decision_maker_blocks(agent)
assert "My Orchestrator" in validator.errors[0]
def test_multiple_sdm_nodes_mixed(self):
"""One valid and one invalid SDM node: only the invalid one errors."""
validator = AgentValidator()
sdm_valid = _make_sdm_node()
sdm_invalid = _make_sdm_node()
tool = _make_agent_executor_node()
agent = {
"nodes": [sdm_valid, sdm_invalid, tool],
"links": [
_link(sdm_valid["id"], "tools", tool["id"], "query"),
# sdm_invalid has no tool links
],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert len(validator.errors) == 1
assert sdm_invalid["id"] in validator.errors[0]
def test_sdm_with_traditional_mode_fails(self):
"""agent_mode_max_iterations=0 (traditional mode) is rejected."""
validator = AgentValidator()
sdm = _make_sdm_node(input_default={"agent_mode_max_iterations": 0})
tool = _make_agent_executor_node()
agent = {
"nodes": [sdm, tool],
"links": [_link(sdm["id"], "tools", tool["id"], "query")],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert any("agent_mode_max_iterations=0" in e for e in validator.errors)
def test_sdm_with_negative_iterations_below_minus_one_fails(self):
"""agent_mode_max_iterations < -1 is rejected."""
validator = AgentValidator()
sdm = _make_sdm_node(input_default={"agent_mode_max_iterations": -5})
tool = _make_agent_executor_node()
agent = {
"nodes": [sdm, tool],
"links": [_link(sdm["id"], "tools", tool["id"], "query")],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert any("invalid" in e and "-5" in e for e in validator.errors)
def test_sdm_with_only_interface_block_links_fails(self):
"""Links to AgentInput/OutputBlocks don't count as tool connections."""
validator = AgentValidator()
sdm = _make_sdm_node()
input_node = _make_input_node()
output_node = _make_output_node()
agent = {
"nodes": [sdm, input_node, output_node],
"links": [
# These link to interface blocks, not real tools
_link(sdm["id"], "tools", input_node["id"], "name"),
_link(sdm["id"], "tools", output_node["id"], "value"),
],
}
result = validator.validate_smart_decision_maker_blocks(agent)
assert result is False
assert len(validator.errors) == 1
assert "no downstream tool blocks" in validator.errors[0]
def test_registered_in_validate(self):
"""validate_smart_decision_maker_blocks runs as part of validate()."""
validator = AgentValidator()
sdm = _make_sdm_node()
agent = {
"id": _uid(),
"version": 1,
"is_active": True,
"name": "Test",
"description": "test",
"nodes": [sdm, _make_input_node(), _make_output_node()],
"links": [],
}
# Build a minimal blocks list with the SDM block info
blocks = [
{
"id": SMART_DECISION_MAKER_BLOCK_ID,
"name": "SmartDecisionMakerBlock",
"inputSchema": {"properties": {"prompt": {"type": "string"}}},
"outputSchema": {
"properties": {
"tools": {},
"finished": {"type": "string"},
"conversations": {"type": "array"},
}
},
},
{
"id": AGENT_INPUT_BLOCK_ID,
"name": "AgentInputBlock",
"inputSchema": {
"properties": {"name": {"type": "string"}},
"required": ["name"],
},
"outputSchema": {"properties": {"result": {}}},
},
{
"id": AGENT_OUTPUT_BLOCK_ID,
"name": "AgentOutputBlock",
"inputSchema": {
"properties": {
"name": {"type": "string"},
"value": {},
},
"required": ["name"],
},
"outputSchema": {"properties": {"output": {}}},
},
]
is_valid, error_msg = validator.validate(agent, blocks)
assert is_valid is False
assert error_msg is not None
assert "no downstream tool blocks" in error_msg
# ---------------------------------------------------------------------------
# E2E pipeline test: fix → validate for a complete orchestrator agent
# ---------------------------------------------------------------------------
class TestSmartDecisionMakerE2EPipeline:
"""End-to-end tests: build agent JSON → fix → validate."""
def test_orchestrator_agent_fix_then_validate(self):
"""A well-formed orchestrator agent passes fix + validate."""
agent = _make_orchestrator_agent()
# Fix
fixer = AgentFixer()
fixed = fixer.apply_all_fixes(agent)
# Verify defaults were applied
sdm_nodes = [
n for n in fixed["nodes"] if n["block_id"] == SMART_DECISION_MAKER_BLOCK_ID
]
assert len(sdm_nodes) == 1
assert sdm_nodes[0]["input_default"]["agent_mode_max_iterations"] == 10
assert sdm_nodes[0]["input_default"]["conversation_compaction"] is True
# Validate (standalone SDM check)
validator = AgentValidator()
assert validator.validate_smart_decision_maker_blocks(fixed) is True
def test_bare_sdm_no_tools_fix_then_validate(self):
"""SDM without tools: fixer fills defaults, validator catches error."""
input_node = _make_input_node()
sdm_node = _make_sdm_node()
output_node = _make_output_node()
agent = {
"id": _uid(),
"version": 1,
"is_active": True,
"name": "Bare SDM Agent",
"description": "SDM with no tools",
"nodes": [input_node, sdm_node, output_node],
"links": [
_link(input_node["id"], "result", sdm_node["id"], "prompt"),
_link(sdm_node["id"], "finished", output_node["id"], "value"),
],
}
# Fix fills defaults fine
fixer = AgentFixer()
fixed = fixer.apply_all_fixes(agent)
assert fixed["nodes"][1]["input_default"]["agent_mode_max_iterations"] == 10
# Validate catches missing tools
validator = AgentValidator()
assert validator.validate_smart_decision_maker_blocks(fixed) is False
assert any("no downstream tool blocks" in e for e in validator.errors)
def test_sdm_with_user_set_bounded_iterations(self):
"""User-set bounded iterations are preserved through fix pipeline."""
agent = _make_orchestrator_agent()
# Simulate user setting bounded iterations
for node in agent["nodes"]:
if node["block_id"] == SMART_DECISION_MAKER_BLOCK_ID:
node["input_default"]["agent_mode_max_iterations"] = 5
node["input_default"]["sys_prompt"] = "You are a helpful orchestrator"
fixer = AgentFixer()
fixed = fixer.apply_all_fixes(agent)
sdm = next(
n for n in fixed["nodes"] if n["block_id"] == SMART_DECISION_MAKER_BLOCK_ID
)
assert sdm["input_default"]["agent_mode_max_iterations"] == 5
assert sdm["input_default"]["sys_prompt"] == "You are a helpful orchestrator"
# Other defaults still filled
assert sdm["input_default"]["conversation_compaction"] is True
assert sdm["input_default"]["retry"] == 3
def test_full_pipeline_with_blocks_list(self):
"""Full validate() with blocks list for a valid orchestrator agent."""
agent = _make_orchestrator_agent()
fixer = AgentFixer()
fixed = fixer.apply_all_fixes(agent)
blocks = [
{
"id": SMART_DECISION_MAKER_BLOCK_ID,
"name": "SmartDecisionMakerBlock",
"inputSchema": {
"properties": {
"prompt": {"type": "string"},
"model": {"type": "object"},
"sys_prompt": {"type": "string"},
"agent_mode_max_iterations": {"type": "integer"},
"conversation_compaction": {"type": "boolean"},
"retry": {"type": "integer"},
"multiple_tool_calls": {"type": "boolean"},
},
"required": ["prompt"],
},
"outputSchema": {
"properties": {
"tools": {},
"finished": {"type": "string"},
"conversations": {"type": "array"},
}
},
},
{
"id": AGENT_EXECUTOR_BLOCK_ID,
"name": "AgentExecutorBlock",
"inputSchema": {
"properties": {
"graph_id": {"type": "string"},
"graph_version": {"type": "integer"},
"input_schema": {"type": "object"},
"output_schema": {"type": "object"},
"user_id": {"type": "string"},
"inputs": {"type": "object"},
"query": {"type": "string"},
},
"required": ["graph_id"],
},
"outputSchema": {
"properties": {"result": {"type": "string"}},
},
},
{
"id": AGENT_INPUT_BLOCK_ID,
"name": "AgentInputBlock",
"inputSchema": {
"properties": {"name": {"type": "string"}},
"required": ["name"],
},
"outputSchema": {"properties": {"result": {}}},
},
{
"id": AGENT_OUTPUT_BLOCK_ID,
"name": "AgentOutputBlock",
"inputSchema": {
"properties": {
"name": {"type": "string"},
"value": {},
},
"required": ["name"],
},
"outputSchema": {"properties": {"output": {}}},
},
]
validator = AgentValidator()
is_valid, error_msg = validator.validate(fixed, blocks)
# Full graph validation should pass
assert is_valid, f"Validation failed: {error_msg}"
# SDM-specific validation should pass (has tool links)
sdm_errors = [e for e in validator.errors if "SmartDecisionMakerBlock" in e]
assert len(sdm_errors) == 0, f"Unexpected SDM errors: {sdm_errors}"

View File

@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
import {
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
case "tool-search_docs":
case "tool-get_doc_page":
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
case "tool-connect_integration":
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
case "tool-run_block":
case "tool-continue_run_block":
return <RunBlockTool key={key} part={part as ToolUIPart} />;

View File

@@ -1,104 +0,0 @@
"use client";
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
import type { ToolUIPart } from "ai";
import { useState } from "react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
type Props = {
part: ToolUIPart;
};
function parseJson(raw: unknown): unknown {
if (typeof raw === "string") {
try {
return JSON.parse(raw);
} catch {
return null;
}
}
return raw;
}
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
return parsed as SetupRequirementsResponse;
}
return null;
}
function parseError(raw: unknown): string | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "message" in parsed) {
return String((parsed as { message: unknown }).message);
}
return null;
}
export function ConnectIntegrationTool({ part }: Props) {
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
const [isDismissed, setIsDismissed] = useState(false);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
const output =
part.state === "output-available"
? parseOutput((part as { output?: unknown }).output)
: null;
const errorMessage = isError
? (parseError((part as { output?: unknown }).output) ??
"Failed to connect integration")
: null;
const rawProvider =
(part as { input?: { provider?: string } }).input?.provider ?? "";
const providerName =
output?.setup_info?.agent_name ??
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
// prevent runaway text in the DOM.
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
const label = isStreaming
? `Connecting ${providerName}`
: isError
? `Failed to connect ${providerName}`
: output
? `Connect ${output.setup_info?.agent_name ?? providerName}`
: `Connect ${providerName}`;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<MorphingTextAnimation
text={label}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isError && errorMessage && (
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
)}
{output && (
<div className="mt-2">
{isDismissed ? (
<ContentMessage>Connected. Continuing</ContentMessage>
) : (
<SetupRequirementsCard
output={output}
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
retryInstruction="I've connected my account. Please continue."
onComplete={() => setIsDismissed(true)}
/>
)}
</div>
)}
</div>
);
}

View File

@@ -23,16 +23,12 @@ interface Props {
/** Override the label shown above the credentials section.
* Defaults to "Credentials". */
credentialsLabel?: string;
/** Called after Proceed is clicked so the parent can persist the dismissed state
* across remounts (avoids re-enabling the Proceed button on remount). */
onComplete?: () => void;
}
export function SetupRequirementsCard({
output,
retryInstruction,
credentialsLabel,
onComplete,
}: Props) {
const { onSend } = useCopilotChatActions();
@@ -72,17 +68,13 @@ export function SetupRequirementsCard({
return v !== undefined && v !== null && v !== "";
});
if (hasSent) {
return <ContentMessage>Connected. Continuing</ContentMessage>;
}
const canRun =
!hasSent &&
(!needsCredentials || isAllCredentialsComplete) &&
(!needsInputs || isAllInputsComplete);
function handleRun() {
setHasSent(true);
onComplete?.();
const parts: string[] = [];
if (needsCredentials) {

View File

@@ -125,9 +125,9 @@ export function useCredentialsInput({
if (hasAttemptedAutoSelect.current) return;
hasAttemptedAutoSelect.current = true;
// Auto-select only when there is exactly one saved credential.
// With multiple options the user must choose — regardless of optional/required.
if (savedCreds.length > 1) return;
// Auto-select if exactly one credential matches.
// For optional fields with multiple options, let the user choose.
if (isOptional && savedCreds.length > 1) return;
const cred = savedCreds[0];
onSelectCredential({