mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
8 Commits
feat/githu
...
abhi/build
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c41a99b76c | ||
|
|
010fc5f0aa | ||
|
|
0f622ff49f | ||
|
|
1e90bcc6d3 | ||
|
|
b8d076b7e4 | ||
|
|
1c18d2e991 | ||
|
|
83a1578746 | ||
|
|
e6d581157f |
@@ -11,10 +11,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
@@ -181,8 +178,7 @@ class FileOperation(StrEnum):
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
@@ -279,11 +275,11 @@ class GithubMultiFileCommitBlock(Block):
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
@@ -305,19 +301,10 @@ class GithubMultiFileCommitBlock(Block):
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
@@ -371,36 +358,15 @@ class GithubMultiFileCommitBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
@@ -98,11 +97,7 @@ async def test_multi_file_commit_error_path():
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
@@ -84,17 +82,6 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -52,68 +52,17 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
@@ -124,7 +73,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -139,7 +87,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -173,16 +120,12 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -200,11 +143,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -219,7 +158,6 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,45 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,6 @@ from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -63,6 +62,7 @@ from ..service import (
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -565,7 +565,7 @@ async def _prepare_file_attachments(
|
||||
return empty
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create workspace manager for file attachments",
|
||||
@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -783,9 +783,7 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
|
||||
@@ -20,7 +20,7 @@ class _FakeFileInfo:
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
|
||||
@@ -347,7 +347,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
|
||||
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
@@ -361,9 +361,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
@@ -391,12 +389,11 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
schema = _build_input_schema(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
schema,
|
||||
)(_truncating(handler, tool_name, input_schema=schema))
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
|
||||
@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -85,7 +84,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -32,7 +32,6 @@ import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
@@ -44,6 +43,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_name)
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestHasLocalSession:
|
||||
# _save_browser_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
||||
|
||||
|
||||
def _make_mock_manager():
|
||||
|
||||
@@ -22,7 +22,6 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -97,9 +96,7 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -136,27 +133,14 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs=envs,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -12,7 +12,6 @@ from backend.copilot.constants import (
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
@@ -198,29 +197,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs in input_data with the block's input
|
||||
# schema. The generic _truncating wrapper skips opaque object
|
||||
# properties (input_data has no declared inner properties in the
|
||||
# tool schema), so file ref tokens are still intact here.
|
||||
# Using the block's schema lets us return raw text for string-typed
|
||||
# fields and parsed structures for list/dict-typed fields.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data,
|
||||
user_id,
|
||||
session,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
@@ -10,11 +10,11 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_workspace_manager,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -218,6 +218,12 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
@@ -380,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -530,7 +536,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -766,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -893,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -25,35 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -98,11 +69,7 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -189,8 +156,6 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -203,17 +168,10 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -275,12 +275,13 @@ async def store_media_file(
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
# Data URI
|
||||
parsed_uri = parse_data_uri(file)
|
||||
if parsed_uri is None:
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Invalid data URI format. Expected data:<mime>;base64,<data>"
|
||||
)
|
||||
mime_type, b64_content = parsed_uri
|
||||
mime_type = match.group(1).strip().lower()
|
||||
b64_content = match.group(2).strip()
|
||||
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
@@ -414,70 +415,13 @@ def get_dir_size(path: Path) -> int:
|
||||
return total
|
||||
|
||||
|
||||
async def resolve_media_content(
|
||||
content: MediaFileType,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
|
||||
|
||||
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
|
||||
Plain text content (source code, filenames) is returned unchanged. Media
|
||||
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
|
||||
:func:`store_media_file` using *return_format*.
|
||||
|
||||
Use this when a block field is typed as ``MediaFileType`` but may contain
|
||||
either literal text or a media reference.
|
||||
"""
|
||||
if not content or not is_media_file_ref(content):
|
||||
return content
|
||||
return await store_media_file(
|
||||
content, execution_context, return_format=return_format
|
||||
)
|
||||
|
||||
|
||||
def is_media_file_ref(value: str) -> bool:
|
||||
"""Return True if *value* looks like a ``MediaFileType`` reference.
|
||||
|
||||
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
|
||||
formats accepted by :func:`store_media_file`. Plain text content
|
||||
(e.g. source code, filenames) returns False.
|
||||
|
||||
Known limitation: HTTP(S) URL detection is heuristic. Any string that
|
||||
starts with ``http://`` or ``https://`` is treated as a media URL, even
|
||||
if it appears as a URL inside source-code comments or documentation.
|
||||
Blocks that produce source code or Markdown as output may therefore
|
||||
trigger false positives. Callers that need higher precision should
|
||||
inspect the string further (e.g. verify the URL is reachable or has a
|
||||
media-friendly extension).
|
||||
|
||||
Note: this does *not* match local file paths, which are ambiguous
|
||||
(could be filenames or actual paths). Blocks that need to resolve
|
||||
local paths should check for them separately.
|
||||
"""
|
||||
return value.startswith(("data:", "workspace://", "http://", "https://"))
|
||||
|
||||
|
||||
def parse_data_uri(value: str) -> tuple[str, str] | None:
|
||||
"""Parse a ``data:<mime>;base64,<payload>`` URI.
|
||||
|
||||
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
|
||||
or ``None`` if it is not.
|
||||
"""
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1).strip().lower(), match.group(2).strip()
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
"""
|
||||
if file.startswith("data:"):
|
||||
parsed_uri = parse_data_uri(file)
|
||||
return parsed_uri[0] if parsed_uri else "application/octet-stream"
|
||||
match = re.match(r"^data:([^;]+);base64,", file)
|
||||
return match.group(1) if match else "application/octet-stream"
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
parsed_url = urlparse(file)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""Parse file content into structured Python objects based on file format.
|
||||
|
||||
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
|
||||
formats into native Python types *before* schema-driven coercion runs. This
|
||||
lets blocks with ``Any``-typed inputs receive structured data rather than raw
|
||||
strings, while blocks expecting strings get the value coerced back via
|
||||
``convert()``.
|
||||
|
||||
Supported formats:
|
||||
|
||||
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
|
||||
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
|
||||
when all lines are dicts with the same keys (tabular data), output is
|
||||
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
|
||||
otherwise returns a plain ``list`` of parsed values
|
||||
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
|
||||
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
|
||||
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
|
||||
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
|
||||
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tomllib
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
|
||||
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
|
||||
# unlike os.path.splitext which uses platform-native separators.
|
||||
from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension / MIME → format label mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_TO_FORMAT: dict[str, str] = {
|
||||
".json": "json",
|
||||
".jsonl": "jsonl",
|
||||
".ndjson": "jsonl",
|
||||
".csv": "csv",
|
||||
".tsv": "tsv",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".toml": "toml",
|
||||
".parquet": "parquet",
|
||||
".xlsx": "xlsx",
|
||||
}
|
||||
|
||||
MIME_TO_FORMAT: dict[str, str] = {
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "jsonl",
|
||||
"application/jsonl": "jsonl",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
||||
}
|
||||
|
||||
# Formats that require raw bytes rather than decoded text.
|
||||
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_format_from_uri(uri: str) -> str | None:
|
||||
"""Return a format label based on URI extension or MIME fragment.
|
||||
|
||||
Returns ``None`` when the format cannot be determined — the caller should
|
||||
fall back to returning the content as a plain string.
|
||||
"""
|
||||
# 1. Check MIME fragment (workspace://abc123#application/json)
|
||||
if "#" in uri:
|
||||
_, fragment = uri.rsplit("#", 1)
|
||||
fmt = MIME_TO_FORMAT.get(fragment.lower())
|
||||
if fmt:
|
||||
return fmt
|
||||
|
||||
# 2. Check file extension from the path portion.
|
||||
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
|
||||
path = uri.split("#")[0].split("?")[0]
|
||||
_, ext = splitext(path)
|
||||
fmt = _EXT_TO_FORMAT.get(ext.lower())
|
||||
if fmt is not None:
|
||||
return fmt
|
||||
|
||||
# Legacy .xls is not supported — map it so callers can produce a
|
||||
# user-friendly error instead of returning garbled binary.
|
||||
if ext.lower() == ".xls":
|
||||
return "xls"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
|
||||
"""Parse *content* according to *fmt* and return a native Python value.
|
||||
|
||||
When *strict* is ``False`` (default), returns the original *content*
|
||||
unchanged if *fmt* is not recognised or parsing fails for any reason.
|
||||
This mode **never raises**.
|
||||
|
||||
When *strict* is ``True``, parsing errors are propagated to the caller.
|
||||
Unrecognised formats or type mismatches (e.g. text for a binary format)
|
||||
still return *content* unchanged without raising.
|
||||
"""
|
||||
if fmt == "xls":
|
||||
return (
|
||||
"[Unsupported format] Legacy .xls files are not supported. "
|
||||
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
|
||||
)
|
||||
|
||||
try:
|
||||
if fmt in BINARY_FORMATS:
|
||||
parser = _BINARY_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
# Caller gave us text for a binary format — can't parse.
|
||||
return content
|
||||
return parser(content)
|
||||
|
||||
parser = _TEXT_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
return parser(content)
|
||||
|
||||
except PARSE_EXCEPTIONS:
|
||||
if strict:
|
||||
raise
|
||||
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
|
||||
return content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_openpyxl_exception() -> type[Exception]:
|
||||
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
|
||||
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
|
||||
|
||||
return InvalidFileException
|
||||
|
||||
|
||||
def _load_arrow_exception() -> type[Exception]:
|
||||
"""Return pyarrow's ArrowException, raising ImportError if absent."""
|
||||
from pyarrow import ArrowException # noqa: PLC0415
|
||||
|
||||
return ArrowException
|
||||
|
||||
|
||||
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
|
||||
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
|
||||
try:
|
||||
return loader()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
# Exception types that can be raised during file content parsing.
|
||||
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
|
||||
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
|
||||
#
|
||||
# Optional-dependency exception types are loaded via a helper that raises
|
||||
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
|
||||
# This ensures mypy sees clean types and missing deps surface as real errors.
|
||||
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
|
||||
exc
|
||||
for exc in (
|
||||
json.JSONDecodeError,
|
||||
csv.Error,
|
||||
yaml.YAMLError,
|
||||
tomllib.TOMLDecodeError,
|
||||
ValueError,
|
||||
UnicodeDecodeError,
|
||||
ImportError,
|
||||
OSError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
_optional_exc(_load_openpyxl_exception),
|
||||
# ArrowException covers ArrowIOError and ArrowCapacityError which
|
||||
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
|
||||
# already map to ValueError/TypeError but this catches the rest.
|
||||
_optional_exc(_load_arrow_exception),
|
||||
)
|
||||
if exc is not None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text-based parsers (content: str → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
|
||||
"""Parse *content* and return the result only if it is a container (list/dict).
|
||||
|
||||
Scalar values (strings, numbers, booleans, None) are discarded and the
|
||||
original *content* string is returned instead. This prevents e.g. a JSON
|
||||
file containing just ``"42"`` from silently becoming an int.
|
||||
"""
|
||||
parsed = parser(content)
|
||||
if isinstance(parsed, (list, dict)):
|
||||
return parsed
|
||||
return content
|
||||
|
||||
|
||||
def _parse_json(content: str) -> list | dict | str:
|
||||
return _parse_container(json.loads, content)
|
||||
|
||||
|
||||
def _parse_jsonl(content: str) -> Any:
|
||||
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return content
|
||||
|
||||
# When every line is a dict with the same keys, convert to table format
|
||||
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
|
||||
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
|
||||
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
|
||||
keys = list(lines[0].keys())
|
||||
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
|
||||
keys_tuple = tuple(keys)
|
||||
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
|
||||
return [keys] + [[obj[k] for k in keys] for obj in lines]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _parse_csv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter=",")
|
||||
|
||||
|
||||
def _parse_tsv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter="\t")
|
||||
|
||||
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
# csv.reader never yields [] — blank lines yield [""]. Filter out
|
||||
# rows where every cell is empty (i.e. truly blank lines).
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
|
||||
|
||||
def _row_has_content(row: list[str]) -> bool:
|
||||
"""Return True when *row* contains at least one non-empty cell.
|
||||
|
||||
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
|
||||
This predicate filters those out consistently across the initial read
|
||||
and the sniffer-fallback re-read.
|
||||
"""
|
||||
return any(cell for cell in row)
|
||||
|
||||
|
||||
def _parse_yaml(content: str) -> list | dict | str:
|
||||
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
|
||||
# safe_load prevents code execution; for production hardening consider
|
||||
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
|
||||
if "\n---" in content or content.startswith("---\n"):
|
||||
# Multi-document YAML: only the first document is parsed; the rest
|
||||
# are silently ignored by yaml.safe_load. Warn so callers are aware.
|
||||
logger.warning(
|
||||
"Multi-document YAML detected (--- separator); "
|
||||
"only the first document will be parsed."
|
||||
)
|
||||
return _parse_container(yaml.safe_load, content)
|
||||
|
||||
|
||||
def _parse_toml(content: str) -> Any:
|
||||
parsed = tomllib.loads(content)
|
||||
# tomllib.loads always returns a dict — return it even if empty.
|
||||
return parsed
|
||||
|
||||
|
||||
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
|
||||
"json": _parse_json,
|
||||
"jsonl": _parse_jsonl,
|
||||
"csv": _parse_csv,
|
||||
"tsv": _parse_tsv,
|
||||
"yaml": _parse_yaml,
|
||||
"toml": _parse_toml,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary-based parsers (content: bytes → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_parquet(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet(io.BytesIO(content))
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _parse_xlsx(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
# Explicitly specify openpyxl engine; the default engine varies by pandas
|
||||
# version and does not support legacy .xls (which is excluded by our format map).
|
||||
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _df_to_rows(df: Any) -> list[list[Any]]:
|
||||
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
|
||||
|
||||
NaN values are replaced with ``None`` so the result is JSON-serializable.
|
||||
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
|
||||
silently converts ``None`` back to ``NaN`` in float64 columns.
|
||||
"""
|
||||
header = df.columns.tolist()
|
||||
rows = [
|
||||
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
|
||||
]
|
||||
return [header] + rows
|
||||
|
||||
|
||||
def _is_nan(cell: Any) -> bool:
|
||||
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
|
||||
|
||||
``pd.isna()`` on a list/dict returns a boolean array which raises
|
||||
``ValueError`` in a boolean context. Guard with a scalar check first.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
|
||||
|
||||
|
||||
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
|
||||
"parquet": _parse_parquet,
|
||||
"xlsx": _parse_xlsx,
|
||||
}
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for file_content_parser — format inference and structured parsing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# infer_format_from_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferFormat:
|
||||
# --- extension-based ---
|
||||
|
||||
def test_json_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.json") == "json"
|
||||
|
||||
def test_jsonl_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
|
||||
|
||||
def test_ndjson_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
|
||||
|
||||
def test_csv_extension(self):
|
||||
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
|
||||
|
||||
def test_tsv_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
|
||||
|
||||
def test_yaml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
|
||||
|
||||
def test_yml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
|
||||
|
||||
def test_toml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.toml") == "toml"
|
||||
|
||||
def test_parquet_extension(self):
|
||||
assert infer_format_from_uri("/data/table.parquet") == "parquet"
|
||||
|
||||
def test_xlsx_extension(self):
|
||||
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
|
||||
|
||||
def test_xls_extension_returns_xls_label(self):
|
||||
# Legacy .xls is mapped so callers can produce a helpful error.
|
||||
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert infer_format_from_uri("/data/FILE.JSON") == "json"
|
||||
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
|
||||
|
||||
def test_unicode_filename(self):
|
||||
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
|
||||
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
|
||||
|
||||
def test_unknown_extension(self):
|
||||
assert infer_format_from_uri("/home/user/readme.txt") is None
|
||||
|
||||
def test_no_extension(self):
|
||||
assert infer_format_from_uri("workspace://abc123") is None
|
||||
|
||||
# --- MIME-based ---
|
||||
|
||||
def test_mime_json(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
|
||||
|
||||
def test_mime_csv(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
|
||||
|
||||
def test_mime_tsv(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
|
||||
== "tsv"
|
||||
)
|
||||
|
||||
def test_mime_ndjson(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
|
||||
)
|
||||
|
||||
def test_mime_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
|
||||
|
||||
def test_mime_xlsx(self):
|
||||
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
assert infer_format_from_uri(uri) == "xlsx"
|
||||
|
||||
def test_mime_parquet(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
|
||||
== "parquet"
|
||||
)
|
||||
|
||||
def test_unknown_mime(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/plain") is None
|
||||
|
||||
def test_unknown_mime_falls_through_to_extension(self):
|
||||
# Unknown MIME (text/plain) should fall through to extension-based detection.
|
||||
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
|
||||
|
||||
# --- MIME takes precedence over extension ---
|
||||
|
||||
def test_mime_overrides_extension(self):
|
||||
# .txt extension but JSON MIME → json
|
||||
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJson:
|
||||
def test_array(self):
|
||||
result = parse_file_content("[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_object(self):
|
||||
result = parse_file_content('{"key": "value"}', "json")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested(self):
|
||||
content = json.dumps({"rows": [[1, 2], [3, 4]]})
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == {"rows": [[1, 2], [3, 4]]}
|
||||
|
||||
def test_scalar_string_stays_as_string(self):
|
||||
result = parse_file_content('"hello"', "json")
|
||||
assert result == '"hello"' # original content, not parsed
|
||||
|
||||
def test_scalar_number_stays_as_string(self):
|
||||
result = parse_file_content("42", "json")
|
||||
assert result == "42"
|
||||
|
||||
def test_scalar_boolean_stays_as_string(self):
|
||||
result = parse_file_content("true", "json")
|
||||
assert result == "true"
|
||||
|
||||
def test_null_stays_as_string(self):
|
||||
result = parse_file_content("null", "json")
|
||||
assert result == "null"
|
||||
|
||||
def test_invalid_json_fallback(self):
|
||||
content = "not json at all"
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == content
|
||||
|
||||
def test_empty_string_fallback(self):
|
||||
result = parse_file_content("", "json")
|
||||
assert result == ""
|
||||
|
||||
def test_bytes_input_decoded(self):
|
||||
result = parse_file_content(b"[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJsonl:
|
||||
def test_tabular_uniform_dicts_to_table_format(self):
|
||||
"""JSONL with uniform dict keys → table format (header + rows),
|
||||
consistent with CSV/TSV/Parquet/Excel output."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", "yellow"],
|
||||
["cherry", "red"],
|
||||
]
|
||||
|
||||
def test_tabular_single_key_dicts(self):
|
||||
"""JSONL with single-key uniform dicts → table format."""
|
||||
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2], [3]]
|
||||
|
||||
def test_tabular_blank_lines_skipped(self):
|
||||
content = '{"a": 1}\n\n{"a": 2}\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2]]
|
||||
|
||||
def test_heterogeneous_dicts_stay_as_list(self):
|
||||
"""JSONL with different keys across objects → list of dicts (no table)."""
|
||||
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
|
||||
|
||||
def test_partially_overlapping_keys_stay_as_list(self):
|
||||
"""JSONL dicts with partially overlapping keys → list of dicts."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
{"name": "apple", "color": "red"},
|
||||
{"name": "banana", "size": "medium"},
|
||||
]
|
||||
|
||||
def test_mixed_types_stay_as_list(self):
|
||||
"""JSONL with non-dict lines → list of parsed values (no table)."""
|
||||
content = '1\n"hello"\n[1,2]\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [1, "hello", [1, 2]]
|
||||
|
||||
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
|
||||
"""JSONL mixing dicts and non-dicts → list of parsed values."""
|
||||
content = '{"a": 1}\n42\n{"b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1}, 42, {"b": 2}]
|
||||
|
||||
def test_tabular_preserves_key_order(self):
|
||||
"""Table header should follow the key order of the first object."""
|
||||
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result[0] == ["z", "a"] # order from first object
|
||||
assert result[1] == [1, 2]
|
||||
assert result[2] == [3, 4]
|
||||
|
||||
def test_single_dict_stays_as_list(self):
|
||||
"""Single-line JSONL with one dict → [dict], NOT a table.
|
||||
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
|
||||
content = '{"a": 1, "b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1, "b": 2}]
|
||||
|
||||
def test_tabular_with_none_values(self):
|
||||
"""Uniform keys but some null values → table with None cells."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", None],
|
||||
]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "jsonl")
|
||||
assert result == ""
|
||||
|
||||
def test_all_blank_lines_fallback(self):
|
||||
result = parse_file_content("\n\n\n", "jsonl")
|
||||
assert result == "\n\n\n"
|
||||
|
||||
def test_invalid_line_fallback(self):
|
||||
content = '{"a": 1}\nnot json\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == content # fallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCsv:
|
||||
def test_basic(self):
|
||||
content = "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_quoted_fields(self):
|
||||
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result[1] == ["Alice", "Loves, commas"]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
# Only 1 column — not tabular enough.
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
def test_empty_rows_skipped(self):
|
||||
content = "A,B\n\n1,2\n\n3,4"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "csv")
|
||||
assert result == ""
|
||||
|
||||
def test_utf8_bom(self):
|
||||
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
|
||||
bom = "\ufeff"
|
||||
content = bom + "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
# The BOM may be part of the first header cell; ensure rows are still parsed.
|
||||
assert len(result) == 3
|
||||
assert result[1] == ["Alice", "90"]
|
||||
assert result[2] == ["Bob", "85"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTsv:
|
||||
def test_basic(self):
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — YAML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseYaml:
|
||||
def test_list(self):
|
||||
content = "- apple\n- banana\n- cherry"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == ["apple", "banana", "cherry"]
|
||||
|
||||
def test_dict(self):
|
||||
content = "name: Alice\nage: 30"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_nested(self):
|
||||
content = "users:\n - name: Alice\n - name: Bob"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
|
||||
|
||||
def test_scalar_stays_as_string(self):
|
||||
result = parse_file_content("hello world", "yaml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_invalid_yaml_fallback(self):
|
||||
content = ":\n :\n invalid: - -"
|
||||
result = parse_file_content(content, "yaml")
|
||||
# Malformed YAML should fall back to the original string, not raise.
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TOML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseToml:
|
||||
def test_basic(self):
|
||||
content = '[server]\nhost = "localhost"\nport = 8080'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"server": {"host": "localhost", "port": 8080}}
|
||||
|
||||
def test_flat(self):
|
||||
content = 'name = "test"\ncount = 42'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"name": "test", "count": 42}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
result = parse_file_content("", "toml")
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_toml_fallback(self):
|
||||
result = parse_file_content("not = [valid toml", "toml")
|
||||
assert result == "not = [valid toml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Parquet (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
|
||||
_has_pyarrow = True
|
||||
except ImportError:
|
||||
_has_pyarrow = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestParseParquet:
|
||||
@pytest.fixture
|
||||
def parquet_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, parquet_bytes: bytes):
|
||||
result = parse_file_content(parquet_bytes, "parquet")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
# Parquet is binary — string input can't be parsed.
|
||||
result = parse_file_content("not parquet", "parquet")
|
||||
assert result == "not parquet"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not parquet bytes", "parquet")
|
||||
assert result == b"not parquet bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "parquet")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in Parquet must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
# Row with NaN in float col → None
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[2][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]:
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Excel (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseExcel:
|
||||
@pytest.fixture
|
||||
def xlsx_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, xlsx_bytes: bytes):
|
||||
result = parse_file_content(xlsx_bytes, "xlsx")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
result = parse_file_content("not xlsx", "xlsx")
|
||||
assert result == "not xlsx"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "xlsx")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in float columns must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type]
|
||||
result = parse_file_content(buf.getvalue(), "xlsx")
|
||||
# Row with NaN in float col → None, not float('nan')
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[3][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]: # skip header
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — unknown format / fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallback:
|
||||
def test_unknown_format_returns_content(self):
|
||||
result = parse_file_content("hello world", "xml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_none_format_returns_content(self):
|
||||
# Shouldn't normally be called with unrecognised format, but must not crash.
|
||||
result = parse_file_content("hello", "unknown_format")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BINARY_FORMATS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinaryFormats:
|
||||
def test_parquet_is_binary(self):
|
||||
assert "parquet" in BINARY_FORMATS
|
||||
|
||||
def test_xlsx_is_binary(self):
|
||||
assert "xlsx" in BINARY_FORMATS
|
||||
|
||||
def test_text_formats_not_binary(self):
|
||||
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
|
||||
assert fmt not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
|
||||
def test_application_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
|
||||
def test_tab_delimited_with_csv_format(self):
|
||||
"""Tab-delimited content parsed as csv should use sniffer fallback."""
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_sniffer_failure_returns_content(self):
|
||||
"""When sniffer fails, single-column falls back to raw content."""
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
|
||||
def test_invalid_xlsx_non_strict(self):
|
||||
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header-only CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaderOnlyCsv:
|
||||
def test_header_only_csv_returns_header_row(self):
|
||||
"""CSV with only a header row (no data rows) should return [[header]]."""
|
||||
content = "Name,Score"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
def test_header_only_csv_with_trailing_newline(self):
|
||||
content = "Name,Score\n"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary format + line range (line range ignored for binary formats)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestBinaryFormatLineRange:
|
||||
def test_parquet_ignores_line_range(self):
|
||||
"""Binary formats should parse the full file regardless of line range.
|
||||
|
||||
Line ranges are meaningless for binary formats (parquet/xlsx) — the
|
||||
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
|
||||
should return the complete structured data.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
# parse_file_content itself doesn't take a line range — this tests
|
||||
# that the full content is parsed even though the bytes could have
|
||||
# been truncated upstream (it's not, by design).
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy .xls UX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestXlsFallback:
|
||||
def test_xls_returns_helpful_error_string(self):
|
||||
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
|
||||
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
def test_xls_with_string_content(self):
|
||||
result = parse_file_content("some text", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import (
|
||||
is_media_file_ref,
|
||||
parse_data_uri,
|
||||
resolve_media_content,
|
||||
store_media_file,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
@@ -349,162 +344,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_media_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsMediaFileRef:
|
||||
def test_data_uri(self):
|
||||
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
|
||||
|
||||
def test_workspace_uri(self):
|
||||
assert is_media_file_ref("workspace://abc123") is True
|
||||
|
||||
def test_workspace_uri_with_mime(self):
|
||||
assert is_media_file_ref("workspace://abc123#image/png") is True
|
||||
|
||||
def test_http_url(self):
|
||||
assert is_media_file_ref("http://example.com/image.png") is True
|
||||
|
||||
def test_https_url(self):
|
||||
assert is_media_file_ref("https://example.com/image.png") is True
|
||||
|
||||
def test_plain_text(self):
|
||||
assert is_media_file_ref("print('hello')") is False
|
||||
|
||||
def test_local_path(self):
|
||||
assert is_media_file_ref("/tmp/file.txt") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
assert is_media_file_ref("") is False
|
||||
|
||||
def test_filename(self):
|
||||
assert is_media_file_ref("image.png") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_data_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDataUri:
|
||||
def test_valid_png(self):
|
||||
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
|
||||
assert result is not None
|
||||
mime, payload = result
|
||||
assert mime == "image/png"
|
||||
assert payload == "iVBORw0KGg=="
|
||||
|
||||
def test_valid_text(self):
|
||||
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
|
||||
assert result is not None
|
||||
assert result[0] == "text/plain"
|
||||
assert result[1] == "SGVsbG8="
|
||||
|
||||
def test_mime_case_normalized(self):
|
||||
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
|
||||
assert result is not None
|
||||
assert result[0] == "image/png"
|
||||
|
||||
def test_not_data_uri(self):
|
||||
assert parse_data_uri("workspace://abc123") is None
|
||||
|
||||
def test_plain_text(self):
|
||||
assert parse_data_uri("hello world") is None
|
||||
|
||||
def test_missing_base64(self):
|
||||
assert parse_data_uri("data:image/png;utf-8,abc") is None
|
||||
|
||||
def test_empty_payload(self):
|
||||
result = parse_data_uri("data:image/png;base64,")
|
||||
assert result is not None
|
||||
assert result[1] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_media_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveMediaContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_passthrough(self):
|
||||
"""Plain text content (not a media ref) passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("print('hello')"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_passthrough(self):
|
||||
"""Empty string passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(""),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_ref_delegates_to_store(self):
|
||||
"""Media references are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_uri_delegates_to_store(self):
|
||||
"""Data URIs are also resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
data_uri = "data:image/png;base64,iVBORw0KGg=="
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType(data_uri)),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(data_uri),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == data_uri
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_url_delegates_to_store(self):
|
||||
"""HTTPS URLs are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
89
autogpt_platform/backend/poetry.lock
generated
89
autogpt_platform/backend/poetry.lock
generated
@@ -1360,18 +1360,6 @@ files = [
|
||||
dnspython = ">=2.0.0"
|
||||
idna = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "et-xmlfile"
|
||||
version = "2.0.0"
|
||||
description = "An implementation of lxml.xmlfile for the standard library"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
|
||||
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exa-py"
|
||||
version = "1.16.1"
|
||||
@@ -4240,21 +4228,6 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
version = "3.1.5"
|
||||
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
|
||||
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
et-xmlfile = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.39.1"
|
||||
@@ -5457,66 +5430,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "23.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
|
||||
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.2"
|
||||
@@ -8969,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
|
||||
@@ -92,8 +92,6 @@ gravitas-md2gdocs = "^0.1.0"
|
||||
posthog = "^7.6.0"
|
||||
fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
openpyxl = "^3.1.5"
|
||||
pyarrow = "^23.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
@@ -0,0 +1,440 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { screen, cleanup } from "@testing-library/react";
|
||||
import { render } from "@/tests/integrations/test-utils";
|
||||
import React from "react";
|
||||
import { BlockUIType } from "../components/types";
|
||||
import type {
|
||||
CustomNodeData,
|
||||
CustomNode as CustomNodeType,
|
||||
} from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import type { NodeProps } from "@xyflow/react";
|
||||
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
|
||||
// ---- Mock sub-components ----
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContainer",
|
||||
() => ({
|
||||
NodeContainer: ({
|
||||
children,
|
||||
hasErrors,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
hasErrors: boolean;
|
||||
}) => (
|
||||
<div data-testid="node-container" data-has-errors={String(!!hasErrors)}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader",
|
||||
() => ({
|
||||
NodeHeader: ({ data }: { data: CustomNodeData }) => (
|
||||
<div data-testid="node-header">{data.title}</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/StickyNoteBlock",
|
||||
() => ({
|
||||
StickyNoteBlock: ({ data }: { data: CustomNodeData }) => (
|
||||
<div data-testid="sticky-note-block">{data.title}</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeAdvancedToggle",
|
||||
() => ({
|
||||
NodeAdvancedToggle: () => <div data-testid="node-advanced-toggle" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput",
|
||||
() => ({
|
||||
NodeDataRenderer: () => <div data-testid="node-data-renderer" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeExecutionBadge",
|
||||
() => ({
|
||||
NodeExecutionBadge: () => <div data-testid="node-execution-badge" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu",
|
||||
() => ({
|
||||
NodeRightClickMenu: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="node-right-click-menu">{children}</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer",
|
||||
() => ({
|
||||
WebhookDisclaimer: () => <div data-testid="webhook-disclaimer" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/SubAgentUpdateFeature",
|
||||
() => ({
|
||||
SubAgentUpdateFeature: () => <div data-testid="sub-agent-update" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/AyrshareConnectButton",
|
||||
() => ({
|
||||
AyrshareConnectButton: () => <div data-testid="ayrshare-connect-button" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/FormCreator",
|
||||
() => ({
|
||||
FormCreator: () => <div data-testid="form-creator" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/OutputHandler",
|
||||
() => ({
|
||||
OutputHandler: () => <div data-testid="output-handler" />,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/components/renderers/InputRenderer/utils/input-schema-pre-processor",
|
||||
() => ({
|
||||
preprocessInputSchema: (schema: unknown) => schema,
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode",
|
||||
() => ({
|
||||
useCustomNode: ({ data }: { data: CustomNodeData }) => ({
|
||||
inputSchema: data.inputSchema,
|
||||
outputSchema: data.outputSchema,
|
||||
isMCPWithTool: false,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock("@xyflow/react", async () => {
|
||||
const actual = await vi.importActual("@xyflow/react");
|
||||
return {
|
||||
...actual,
|
||||
useReactFlow: () => ({
|
||||
getNodes: () => [],
|
||||
getEdges: () => [],
|
||||
setNodes: vi.fn(),
|
||||
setEdges: vi.fn(),
|
||||
getNode: vi.fn(),
|
||||
}),
|
||||
useNodeId: () => "test-node-id",
|
||||
useUpdateNodeInternals: () => vi.fn(),
|
||||
Handle: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
Position: { Left: "left", Right: "right", Top: "top", Bottom: "bottom" },
|
||||
};
|
||||
});
|
||||
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
function buildNodeData(
|
||||
overrides: Partial<CustomNodeData> = {},
|
||||
): CustomNodeData {
|
||||
return {
|
||||
hardcodedValues: {},
|
||||
title: "Test Block",
|
||||
description: "A test block",
|
||||
inputSchema: { type: "object", properties: {} },
|
||||
outputSchema: { type: "object", properties: {} },
|
||||
uiType: BlockUIType.STANDARD,
|
||||
block_id: "block-123",
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function buildNodeProps(
|
||||
dataOverrides: Partial<CustomNodeData> = {},
|
||||
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
|
||||
): NodeProps<CustomNodeType> {
|
||||
return {
|
||||
id: "node-1",
|
||||
data: buildNodeData(dataOverrides),
|
||||
selected: false,
|
||||
type: "custom",
|
||||
isConnectable: true,
|
||||
positionAbsoluteX: 0,
|
||||
positionAbsoluteY: 0,
|
||||
zIndex: 0,
|
||||
dragging: false,
|
||||
dragHandle: undefined,
|
||||
draggable: true,
|
||||
selectable: true,
|
||||
deletable: true,
|
||||
parentId: undefined,
|
||||
width: undefined,
|
||||
height: undefined,
|
||||
sourcePosition: undefined,
|
||||
targetPosition: undefined,
|
||||
...propsOverrides,
|
||||
};
|
||||
}
|
||||
|
||||
function renderCustomNode(
|
||||
dataOverrides: Partial<CustomNodeData> = {},
|
||||
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
|
||||
) {
|
||||
const props = buildNodeProps(dataOverrides, propsOverrides);
|
||||
return render(<CustomNode {...props} />);
|
||||
}
|
||||
|
||||
function createExecutionResult(
|
||||
overrides: Partial<NodeExecutionResult> = {},
|
||||
): NodeExecutionResult {
|
||||
return {
|
||||
node_exec_id: overrides.node_exec_id ?? "exec-1",
|
||||
node_id: overrides.node_id ?? "node-1",
|
||||
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
|
||||
graph_id: overrides.graph_id ?? "graph-1",
|
||||
graph_version: overrides.graph_version ?? 1,
|
||||
user_id: overrides.user_id ?? "test-user",
|
||||
block_id: overrides.block_id ?? "block-1",
|
||||
status: overrides.status ?? "COMPLETED",
|
||||
input_data: overrides.input_data ?? {},
|
||||
output_data: overrides.output_data ?? {},
|
||||
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
|
||||
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
|
||||
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
|
||||
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
|
||||
};
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
beforeEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
describe("CustomNode", () => {
|
||||
describe("STANDARD type rendering", () => {
|
||||
it("renders NodeHeader with the block title", () => {
|
||||
renderCustomNode({ title: "My Standard Block" });
|
||||
|
||||
const header = screen.getByTestId("node-header");
|
||||
expect(header).toBeDefined();
|
||||
expect(header.textContent).toContain("My Standard Block");
|
||||
});
|
||||
|
||||
it("renders NodeContainer, FormCreator, OutputHandler, and NodeExecutionBadge", () => {
|
||||
renderCustomNode();
|
||||
|
||||
expect(screen.getByTestId("node-container")).toBeDefined();
|
||||
expect(screen.getByTestId("form-creator")).toBeDefined();
|
||||
expect(screen.getByTestId("output-handler")).toBeDefined();
|
||||
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
|
||||
expect(screen.getByTestId("node-data-renderer")).toBeDefined();
|
||||
expect(screen.getByTestId("node-advanced-toggle")).toBeDefined();
|
||||
});
|
||||
|
||||
it("wraps content in NodeRightClickMenu", () => {
|
||||
renderCustomNode();
|
||||
|
||||
expect(screen.getByTestId("node-right-click-menu")).toBeDefined();
|
||||
});
|
||||
|
||||
it("does not render StickyNoteBlock for STANDARD type", () => {
|
||||
renderCustomNode();
|
||||
|
||||
expect(screen.queryByTestId("sticky-note-block")).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("NOTE type rendering", () => {
|
||||
it("renders StickyNoteBlock instead of main UI", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.NOTE, title: "My Note" });
|
||||
|
||||
const note = screen.getByTestId("sticky-note-block");
|
||||
expect(note).toBeDefined();
|
||||
expect(note.textContent).toContain("My Note");
|
||||
});
|
||||
|
||||
it("does not render NodeContainer or other standard components", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.NOTE });
|
||||
|
||||
expect(screen.queryByTestId("node-container")).toBeNull();
|
||||
expect(screen.queryByTestId("node-header")).toBeNull();
|
||||
expect(screen.queryByTestId("form-creator")).toBeNull();
|
||||
expect(screen.queryByTestId("output-handler")).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("WEBHOOK type rendering", () => {
|
||||
it("renders WebhookDisclaimer for WEBHOOK type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.WEBHOOK });
|
||||
|
||||
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders WebhookDisclaimer for WEBHOOK_MANUAL type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.WEBHOOK_MANUAL });
|
||||
|
||||
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("AGENT type rendering", () => {
|
||||
it("renders SubAgentUpdateFeature for AGENT type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.AGENT });
|
||||
|
||||
expect(screen.getByTestId("sub-agent-update")).toBeDefined();
|
||||
});
|
||||
|
||||
it("does not render SubAgentUpdateFeature for non-AGENT types", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.STANDARD });
|
||||
|
||||
expect(screen.queryByTestId("sub-agent-update")).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("OUTPUT type rendering", () => {
|
||||
it("does not render OutputHandler for OUTPUT type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.OUTPUT });
|
||||
|
||||
expect(screen.queryByTestId("output-handler")).toBeNull();
|
||||
});
|
||||
|
||||
it("still renders FormCreator and other components for OUTPUT type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.OUTPUT });
|
||||
|
||||
expect(screen.getByTestId("form-creator")).toBeDefined();
|
||||
expect(screen.getByTestId("node-header")).toBeDefined();
|
||||
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("AYRSHARE type rendering", () => {
|
||||
it("renders AyrshareConnectButton for AYRSHARE type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.AYRSHARE });
|
||||
|
||||
expect(screen.getByTestId("ayrshare-connect-button")).toBeDefined();
|
||||
});
|
||||
|
||||
it("does not render AyrshareConnectButton for non-AYRSHARE types", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.STANDARD });
|
||||
|
||||
expect(screen.queryByTestId("ayrshare-connect-button")).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("error states", () => {
|
||||
it("sets hasErrors on NodeContainer when data.errors has non-empty values", () => {
|
||||
renderCustomNode({
|
||||
errors: { field1: "This field is required" },
|
||||
});
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container.getAttribute("data-has-errors")).toBe("true");
|
||||
});
|
||||
|
||||
it("does not set hasErrors when data.errors is empty", () => {
|
||||
renderCustomNode({ errors: {} });
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
|
||||
it("does not set hasErrors when data.errors values are all empty strings", () => {
|
||||
renderCustomNode({ errors: { field1: "" } });
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
|
||||
it("sets hasErrors when last execution result has error in output_data", () => {
|
||||
renderCustomNode({
|
||||
nodeExecutionResults: [
|
||||
createExecutionResult({
|
||||
output_data: { error: ["Something went wrong"] },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container.getAttribute("data-has-errors")).toBe("true");
|
||||
});
|
||||
|
||||
it("does not set hasErrors when execution results have no error", () => {
|
||||
renderCustomNode({
|
||||
nodeExecutionResults: [
|
||||
createExecutionResult({
|
||||
output_data: { result: ["success"] },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
});
|
||||
|
||||
describe("NodeExecutionBadge", () => {
|
||||
it("always renders NodeExecutionBadge for non-NOTE types", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.STANDARD });
|
||||
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders NodeExecutionBadge for AGENT type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.AGENT });
|
||||
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders NodeExecutionBadge for OUTPUT type", () => {
|
||||
renderCustomNode({ uiType: BlockUIType.OUTPUT });
|
||||
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("edge cases", () => {
|
||||
it("renders without nodeExecutionResults", () => {
|
||||
renderCustomNode({ nodeExecutionResults: undefined });
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container).toBeDefined();
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
|
||||
it("renders without errors property", () => {
|
||||
renderCustomNode({ errors: undefined });
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container).toBeDefined();
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
|
||||
it("renders with empty execution results array", () => {
|
||||
renderCustomNode({ nodeExecutionResults: [] });
|
||||
|
||||
const container = screen.getByTestId("node-container");
|
||||
expect(container).toBeDefined();
|
||||
expect(container.getAttribute("data-has-errors")).toBe("false");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,172 @@
|
||||
import React from "react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import { TooltipProvider } from "@radix-ui/react-tooltip";
|
||||
import { DraftRecoveryPopup } from "../components/DraftRecoveryDialog/DraftRecoveryPopup";
|
||||
|
||||
const mockOnLoad = vi.fn();
|
||||
const mockOnDiscard = vi.fn();
|
||||
|
||||
vi.mock("../components/DraftRecoveryDialog/useDraftRecoveryPopup", () => ({
|
||||
useDraftRecoveryPopup: vi.fn(() => ({
|
||||
isOpen: true,
|
||||
popupRef: { current: null },
|
||||
nodeCount: 3,
|
||||
edgeCount: 2,
|
||||
diff: {
|
||||
nodes: { added: 1, removed: 0, modified: 2 },
|
||||
edges: { added: 1, removed: 1, modified: 0 },
|
||||
},
|
||||
savedAt: Date.now(),
|
||||
onLoad: mockOnLoad,
|
||||
onDiscard: mockOnDiscard,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("framer-motion", () => ({
|
||||
AnimatePresence: ({ children }: { children: React.ReactNode }) => (
|
||||
<>{children}</>
|
||||
),
|
||||
motion: {
|
||||
div: React.forwardRef(function MotionDiv(
|
||||
props: Record<string, unknown>,
|
||||
ref: React.Ref<HTMLDivElement>,
|
||||
) {
|
||||
const {
|
||||
children,
|
||||
initial: _initial,
|
||||
animate: _animate,
|
||||
exit: _exit,
|
||||
transition: _transition,
|
||||
...rest
|
||||
} = props as {
|
||||
children?: React.ReactNode;
|
||||
initial?: unknown;
|
||||
animate?: unknown;
|
||||
exit?: unknown;
|
||||
transition?: unknown;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
return (
|
||||
<div ref={ref} {...rest}>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
return render(<TooltipProvider>{ui}</TooltipProvider>);
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("DraftRecoveryPopup", () => {
|
||||
describe("when open with diff data", () => {
|
||||
it("shows the unsaved changes message", () => {
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
expect(screen.getByText("Unsaved changes found")).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays diff summary", () => {
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
const text = document.body.textContent;
|
||||
expect(text).toContain("+1/~2 blocks");
|
||||
expect(text).toContain("+1/-1 connections");
|
||||
});
|
||||
|
||||
it("renders restore and discard buttons", () => {
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
expect(screen.getAllByText("Restore changes").length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText("Discard changes").length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("calls onLoad when restore is clicked", () => {
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
const buttons = screen.getAllByRole("button", {
|
||||
name: /restore changes/i,
|
||||
});
|
||||
fireEvent.click(buttons[0]);
|
||||
expect(mockOnLoad).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("calls onDiscard when discard is clicked", () => {
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
const buttons = screen.getAllByRole("button", {
|
||||
name: /discard changes/i,
|
||||
});
|
||||
fireEvent.click(buttons[0]);
|
||||
expect(mockOnDiscard).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
|
||||
describe("when closed", () => {
|
||||
it("renders nothing when isOpen is false", async () => {
|
||||
const { useDraftRecoveryPopup } = await import(
|
||||
"../components/DraftRecoveryDialog/useDraftRecoveryPopup"
|
||||
);
|
||||
vi.mocked(useDraftRecoveryPopup).mockReturnValue({
|
||||
isOpen: false,
|
||||
popupRef: { current: null },
|
||||
nodeCount: 0,
|
||||
edgeCount: 0,
|
||||
diff: null,
|
||||
savedAt: 0,
|
||||
onLoad: vi.fn(),
|
||||
onDiscard: vi.fn(),
|
||||
});
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<DraftRecoveryPopup isInitialLoadComplete={true} />,
|
||||
);
|
||||
expect(container.textContent).toBe("");
|
||||
});
|
||||
});
|
||||
|
||||
describe("when diff is null", () => {
|
||||
it("falls back to node/edge count display", async () => {
|
||||
const { useDraftRecoveryPopup } = await import(
|
||||
"../components/DraftRecoveryDialog/useDraftRecoveryPopup"
|
||||
);
|
||||
vi.mocked(useDraftRecoveryPopup).mockReturnValue({
|
||||
isOpen: true,
|
||||
popupRef: { current: null },
|
||||
nodeCount: 5,
|
||||
edgeCount: 1,
|
||||
diff: null,
|
||||
savedAt: Date.now(),
|
||||
onLoad: vi.fn(),
|
||||
onDiscard: vi.fn(),
|
||||
});
|
||||
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
const text = document.body.textContent;
|
||||
expect(text).toContain("5 blocks");
|
||||
expect(text).toContain("1 connection");
|
||||
});
|
||||
|
||||
it("uses singular for 1 block", async () => {
|
||||
const { useDraftRecoveryPopup } = await import(
|
||||
"../components/DraftRecoveryDialog/useDraftRecoveryPopup"
|
||||
);
|
||||
vi.mocked(useDraftRecoveryPopup).mockReturnValue({
|
||||
isOpen: true,
|
||||
popupRef: { current: null },
|
||||
nodeCount: 1,
|
||||
edgeCount: 0,
|
||||
diff: null,
|
||||
savedAt: Date.now(),
|
||||
onLoad: vi.fn(),
|
||||
onDiscard: vi.fn(),
|
||||
});
|
||||
|
||||
renderWithProviders(<DraftRecoveryPopup isInitialLoadComplete={true} />);
|
||||
const text = document.body.textContent;
|
||||
expect(text).toContain("1 block,");
|
||||
expect(text).toContain("0 connections");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,342 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
cleanup,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { useBlockMenuStore } from "../stores/blockMenuStore";
|
||||
import { useControlPanelStore } from "../stores/controlPanelStore";
|
||||
import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types";
|
||||
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mocks for heavy child components
|
||||
// ---------------------------------------------------------------------------
|
||||
vi.mock(
|
||||
"../components/NewControlPanel/NewBlockMenu/BlockMenuDefault/BlockMenuDefault",
|
||||
() => ({
|
||||
BlockMenuDefault: () => (
|
||||
<div data-testid="block-menu-default">Default Content</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"../components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch",
|
||||
() => ({
|
||||
BlockMenuSearch: () => (
|
||||
<div data-testid="block-menu-search">Search Results</div>
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
// Mock query client used by the search bar hook
|
||||
vi.mock("@/lib/react-query/queryClient", () => ({
|
||||
getQueryClient: () => ({
|
||||
invalidateQueries: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reset stores before each test
|
||||
// ---------------------------------------------------------------------------
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
useBlockMenuStore.getState().reset();
|
||||
useBlockMenuStore.setState({
|
||||
filters: [],
|
||||
creators: [],
|
||||
creators_list: [],
|
||||
categoryCounts: {
|
||||
blocks: 0,
|
||||
integrations: 0,
|
||||
marketplace_agents: 0,
|
||||
my_agents: 0,
|
||||
},
|
||||
});
|
||||
useControlPanelStore.getState().reset();
|
||||
});
|
||||
|
||||
// ===========================================================================
|
||||
// Section 1: blockMenuStore unit tests
|
||||
// ===========================================================================
|
||||
describe("blockMenuStore", () => {
|
||||
describe("searchQuery", () => {
|
||||
it("defaults to an empty string", () => {
|
||||
expect(useBlockMenuStore.getState().searchQuery).toBe("");
|
||||
});
|
||||
|
||||
it("sets the search query", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("timer");
|
||||
expect(useBlockMenuStore.getState().searchQuery).toBe("timer");
|
||||
});
|
||||
});
|
||||
|
||||
describe("defaultState", () => {
|
||||
it("defaults to SUGGESTION", () => {
|
||||
expect(useBlockMenuStore.getState().defaultState).toBe(
|
||||
DefaultStateType.SUGGESTION,
|
||||
);
|
||||
});
|
||||
|
||||
it("sets the default state", () => {
|
||||
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
|
||||
expect(useBlockMenuStore.getState().defaultState).toBe(
|
||||
DefaultStateType.ALL_BLOCKS,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("filters", () => {
|
||||
it("defaults to an empty array", () => {
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([]);
|
||||
});
|
||||
|
||||
it("adds a filter", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
SearchEntryFilterAnyOfItem.blocks,
|
||||
]);
|
||||
});
|
||||
|
||||
it("removes a filter", () => {
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.setFilters([
|
||||
SearchEntryFilterAnyOfItem.blocks,
|
||||
SearchEntryFilterAnyOfItem.integrations,
|
||||
]);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.removeFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
SearchEntryFilterAnyOfItem.integrations,
|
||||
]);
|
||||
});
|
||||
|
||||
it("replaces all filters with setFilters", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.setFilters([SearchEntryFilterAnyOfItem.marketplace_agents]);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
SearchEntryFilterAnyOfItem.marketplace_agents,
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("creators", () => {
|
||||
it("adds a creator", () => {
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
|
||||
});
|
||||
|
||||
it("removes a creator", () => {
|
||||
useBlockMenuStore.getState().setCreators(["alice", "bob"]);
|
||||
useBlockMenuStore.getState().removeCreator("alice");
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["bob"]);
|
||||
});
|
||||
|
||||
it("replaces all creators with setCreators", () => {
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
useBlockMenuStore.getState().setCreators(["charlie"]);
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["charlie"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("categoryCounts", () => {
|
||||
it("sets category counts", () => {
|
||||
const counts = {
|
||||
blocks: 10,
|
||||
integrations: 5,
|
||||
marketplace_agents: 3,
|
||||
my_agents: 2,
|
||||
};
|
||||
useBlockMenuStore.getState().setCategoryCounts(counts);
|
||||
expect(useBlockMenuStore.getState().categoryCounts).toEqual(counts);
|
||||
});
|
||||
});
|
||||
|
||||
describe("searchId", () => {
|
||||
it("defaults to undefined", () => {
|
||||
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
|
||||
});
|
||||
|
||||
it("sets and clears searchId", () => {
|
||||
useBlockMenuStore.getState().setSearchId("search-123");
|
||||
expect(useBlockMenuStore.getState().searchId).toBe("search-123");
|
||||
|
||||
useBlockMenuStore.getState().setSearchId(undefined);
|
||||
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("integration", () => {
|
||||
it("defaults to undefined", () => {
|
||||
expect(useBlockMenuStore.getState().integration).toBeUndefined();
|
||||
});
|
||||
|
||||
it("sets the integration", () => {
|
||||
useBlockMenuStore.getState().setIntegration("slack");
|
||||
expect(useBlockMenuStore.getState().integration).toBe("slack");
|
||||
});
|
||||
});
|
||||
|
||||
describe("reset", () => {
|
||||
it("resets searchQuery, searchId, defaultState, and integration", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("hello");
|
||||
useBlockMenuStore.getState().setSearchId("id-1");
|
||||
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
|
||||
useBlockMenuStore.getState().setIntegration("github");
|
||||
|
||||
useBlockMenuStore.getState().reset();
|
||||
|
||||
const state = useBlockMenuStore.getState();
|
||||
expect(state.searchQuery).toBe("");
|
||||
expect(state.searchId).toBeUndefined();
|
||||
expect(state.defaultState).toBe(DefaultStateType.SUGGESTION);
|
||||
expect(state.integration).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does not reset filters or creators (by design)", () => {
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.setFilters([SearchEntryFilterAnyOfItem.blocks]);
|
||||
useBlockMenuStore.getState().setCreators(["alice"]);
|
||||
|
||||
useBlockMenuStore.getState().reset();
|
||||
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
SearchEntryFilterAnyOfItem.blocks,
|
||||
]);
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ===========================================================================
|
||||
// Section 2: controlPanelStore unit tests
|
||||
// ===========================================================================
|
||||
describe("controlPanelStore", () => {
|
||||
it("defaults blockMenuOpen to false", () => {
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("sets blockMenuOpen", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(true);
|
||||
});
|
||||
|
||||
it("sets forceOpenBlockMenu", () => {
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(true);
|
||||
expect(useControlPanelStore.getState().forceOpenBlockMenu).toBe(true);
|
||||
});
|
||||
|
||||
it("resets all control panel state", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(true);
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
useControlPanelStore.getState().setForceOpenSave(true);
|
||||
|
||||
useControlPanelStore.getState().reset();
|
||||
|
||||
const state = useControlPanelStore.getState();
|
||||
expect(state.blockMenuOpen).toBe(false);
|
||||
expect(state.forceOpenBlockMenu).toBe(false);
|
||||
expect(state.saveControlOpen).toBe(false);
|
||||
expect(state.forceOpenSave).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
// ===========================================================================
|
||||
// Section 3: BlockMenuContent integration tests
|
||||
// ===========================================================================
|
||||
// We import BlockMenuContent directly to avoid dealing with the Popover wrapper.
|
||||
import { BlockMenuContent } from "../components/NewControlPanel/NewBlockMenu/BlockMenuContent/BlockMenuContent";
|
||||
|
||||
describe("BlockMenuContent", () => {
|
||||
it("shows BlockMenuDefault when there is no search query", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("");
|
||||
|
||||
render(<BlockMenuContent />);
|
||||
|
||||
expect(screen.getByTestId("block-menu-default")).toBeDefined();
|
||||
expect(screen.queryByTestId("block-menu-search")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows BlockMenuSearch when a search query is present", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("timer");
|
||||
|
||||
render(<BlockMenuContent />);
|
||||
|
||||
expect(screen.getByTestId("block-menu-search")).toBeDefined();
|
||||
expect(screen.queryByTestId("block-menu-default")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders the search bar", () => {
|
||||
render(<BlockMenuContent />);
|
||||
|
||||
expect(
|
||||
screen.getByPlaceholderText(
|
||||
"Blocks, Agents, Integrations or Keywords...",
|
||||
),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("switches from default to search view when store query changes", () => {
|
||||
const { rerender } = render(<BlockMenuContent />);
|
||||
expect(screen.getByTestId("block-menu-default")).toBeDefined();
|
||||
|
||||
// Simulate typing by setting the store directly
|
||||
useBlockMenuStore.getState().setSearchQuery("webhook");
|
||||
rerender(<BlockMenuContent />);
|
||||
|
||||
expect(screen.getByTestId("block-menu-search")).toBeDefined();
|
||||
expect(screen.queryByTestId("block-menu-default")).toBeNull();
|
||||
});
|
||||
|
||||
it("switches back to default view when search query is cleared", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("something");
|
||||
const { rerender } = render(<BlockMenuContent />);
|
||||
expect(screen.getByTestId("block-menu-search")).toBeDefined();
|
||||
|
||||
useBlockMenuStore.getState().setSearchQuery("");
|
||||
rerender(<BlockMenuContent />);
|
||||
|
||||
expect(screen.getByTestId("block-menu-default")).toBeDefined();
|
||||
expect(screen.queryByTestId("block-menu-search")).toBeNull();
|
||||
});
|
||||
|
||||
it("typing in the search bar updates the local input value", async () => {
|
||||
render(<BlockMenuContent />);
|
||||
|
||||
const input = screen.getByPlaceholderText(
|
||||
"Blocks, Agents, Integrations or Keywords...",
|
||||
);
|
||||
fireEvent.change(input, { target: { value: "slack" } });
|
||||
|
||||
expect((input as HTMLInputElement).value).toBe("slack");
|
||||
});
|
||||
|
||||
it("shows clear button when input has text and clears on click", async () => {
|
||||
render(<BlockMenuContent />);
|
||||
|
||||
const input = screen.getByPlaceholderText(
|
||||
"Blocks, Agents, Integrations or Keywords...",
|
||||
);
|
||||
fireEvent.change(input, { target: { value: "test" } });
|
||||
|
||||
// The clear button should appear
|
||||
const clearButton = screen.getByRole("button");
|
||||
fireEvent.click(clearButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect((input as HTMLInputElement).value).toBe("");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,270 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
cleanup,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { UseFormReturn, useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import * as z from "zod";
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { useControlPanelStore } from "../stores/controlPanelStore";
|
||||
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { NewSaveControl } from "../components/NewControlPanel/NewSaveControl/NewSaveControl";
|
||||
import { useNewSaveControl } from "../components/NewControlPanel/NewSaveControl/useNewSaveControl";
|
||||
|
||||
const formSchema = z.object({
|
||||
name: z.string().min(1, "Name is required").max(100),
|
||||
description: z.string().max(500),
|
||||
});
|
||||
|
||||
type SaveableGraphFormValues = z.infer<typeof formSchema>;
|
||||
|
||||
const mockHandleSave = vi.fn();
|
||||
|
||||
vi.mock(
|
||||
"../components/NewControlPanel/NewSaveControl/useNewSaveControl",
|
||||
() => ({
|
||||
useNewSaveControl: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
const mockUseNewSaveControl = vi.mocked(useNewSaveControl);
|
||||
|
||||
function createMockForm(
|
||||
defaults: SaveableGraphFormValues = { name: "", description: "" },
|
||||
): UseFormReturn<SaveableGraphFormValues> {
|
||||
const { result } = renderHook(() =>
|
||||
useForm<SaveableGraphFormValues>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: defaults,
|
||||
}),
|
||||
);
|
||||
return result.current;
|
||||
}
|
||||
|
||||
function setupMock(overrides: {
|
||||
isSaving?: boolean;
|
||||
graphVersion?: number;
|
||||
name?: string;
|
||||
description?: string;
|
||||
}) {
|
||||
const form = createMockForm({
|
||||
name: overrides.name ?? "",
|
||||
description: overrides.description ?? "",
|
||||
});
|
||||
|
||||
mockUseNewSaveControl.mockReturnValue({
|
||||
form,
|
||||
isSaving: overrides.isSaving ?? false,
|
||||
graphVersion: overrides.graphVersion,
|
||||
handleSave: mockHandleSave,
|
||||
});
|
||||
|
||||
return form;
|
||||
}
|
||||
|
||||
function resetStore() {
|
||||
useControlPanelStore.setState({
|
||||
blockMenuOpen: false,
|
||||
saveControlOpen: false,
|
||||
forceOpenBlockMenu: false,
|
||||
forceOpenSave: false,
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
cleanup();
|
||||
resetStore();
|
||||
mockHandleSave.mockReset();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
describe("NewSaveControl", () => {
|
||||
it("renders save button trigger", () => {
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("save-control-save-button")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders name and description inputs when popover is open", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
|
||||
expect(screen.getByTestId("save-control-description-input")).toBeDefined();
|
||||
});
|
||||
|
||||
it("does not render popover content when closed", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: false });
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("save-control-name-input")).toBeNull();
|
||||
expect(screen.queryByTestId("save-control-description-input")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows version output when graphVersion is set", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({ graphVersion: 3 });
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const versionInput = screen.getByTestId("save-control-version-output");
|
||||
expect(versionInput).toBeDefined();
|
||||
expect((versionInput as HTMLInputElement).disabled).toBe(true);
|
||||
});
|
||||
|
||||
it("hides version output when graphVersion is undefined", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({ graphVersion: undefined });
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("save-control-version-output")).toBeNull();
|
||||
});
|
||||
|
||||
it("enables save button when isSaving is false", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({ isSaving: false });
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const saveButton = screen.getByTestId("save-control-save-agent-button");
|
||||
expect((saveButton as HTMLButtonElement).disabled).toBe(false);
|
||||
});
|
||||
|
||||
it("disables save button when isSaving is true", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({ isSaving: true });
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save agent/i });
|
||||
expect((saveButton as HTMLButtonElement).disabled).toBe(true);
|
||||
});
|
||||
|
||||
it("calls handleSave on form submission with valid data", async () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
const form = setupMock({ name: "My Agent", description: "A description" });
|
||||
|
||||
form.setValue("name", "My Agent");
|
||||
form.setValue("description", "A description");
|
||||
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const saveButton = screen.getByTestId("save-control-save-agent-button");
|
||||
fireEvent.click(saveButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleSave).toHaveBeenCalledWith(
|
||||
{ name: "My Agent", description: "A description" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("does not call handleSave when name is empty (validation fails)", async () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({ name: "", description: "" });
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const saveButton = screen.getByTestId("save-control-save-agent-button");
|
||||
fireEvent.click(saveButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleSave).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("popover stays open when forceOpenSave is true", () => {
|
||||
useControlPanelStore.setState({
|
||||
saveControlOpen: false,
|
||||
forceOpenSave: true,
|
||||
});
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
|
||||
});
|
||||
|
||||
it("allows typing in name and description inputs", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
const nameInput = screen.getByTestId(
|
||||
"save-control-name-input",
|
||||
) as HTMLInputElement;
|
||||
const descriptionInput = screen.getByTestId(
|
||||
"save-control-description-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
fireEvent.change(nameInput, { target: { value: "Test Agent" } });
|
||||
fireEvent.change(descriptionInput, {
|
||||
target: { value: "Test Description" },
|
||||
});
|
||||
|
||||
expect(nameInput.value).toBe("Test Agent");
|
||||
expect(descriptionInput.value).toBe("Test Description");
|
||||
});
|
||||
|
||||
it("displays save button text", () => {
|
||||
useControlPanelStore.setState({ saveControlOpen: true });
|
||||
setupMock({});
|
||||
render(
|
||||
<TooltipProvider>
|
||||
<NewSaveControl />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("Save Agent")).toBeDefined();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,147 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import { render } from "@/tests/integrations/test-utils";
|
||||
import React from "react";
|
||||
import { useGraphStore } from "../stores/graphStore";
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph",
|
||||
() => ({
|
||||
useRunGraph: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock(
|
||||
"@/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog",
|
||||
() => ({
|
||||
RunInputDialog: ({ isOpen }: { isOpen: boolean }) =>
|
||||
isOpen ? <div data-testid="run-input-dialog">Dialog</div> : null,
|
||||
}),
|
||||
);
|
||||
|
||||
// Must import after mocks
|
||||
import { useRunGraph } from "../components/BuilderActions/components/RunGraph/useRunGraph";
|
||||
import { RunGraph } from "../components/BuilderActions/components/RunGraph/RunGraph";
|
||||
|
||||
const mockUseRunGraph = vi.mocked(useRunGraph);
|
||||
|
||||
function createMockReturnValue(
|
||||
overrides: Partial<ReturnType<typeof useRunGraph>> = {},
|
||||
) {
|
||||
return {
|
||||
handleRunGraph: vi.fn(),
|
||||
handleStopGraph: vi.fn(),
|
||||
openRunInputDialog: false,
|
||||
setOpenRunInputDialog: vi.fn(),
|
||||
isExecutingGraph: false,
|
||||
isTerminatingGraph: false,
|
||||
isSaving: false,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
// RunGraph uses Tooltip which requires TooltipProvider
|
||||
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
|
||||
function renderRunGraph(flowID: string | null = "test-flow-id") {
|
||||
return render(
|
||||
<TooltipProvider>
|
||||
<RunGraph flowID={flowID} />
|
||||
</TooltipProvider>,
|
||||
);
|
||||
}
|
||||
|
||||
describe("RunGraph", () => {
|
||||
beforeEach(() => {
|
||||
cleanup();
|
||||
mockUseRunGraph.mockReturnValue(createMockReturnValue());
|
||||
useGraphStore.setState({ isGraphRunning: false });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
it("renders an enabled button when flowID is provided", () => {
|
||||
renderRunGraph("test-flow-id");
|
||||
const button = screen.getByRole("button");
|
||||
expect((button as HTMLButtonElement).disabled).toBe(false);
|
||||
});
|
||||
|
||||
it("renders a disabled button when flowID is null", () => {
|
||||
renderRunGraph(null);
|
||||
const button = screen.getByRole("button");
|
||||
expect((button as HTMLButtonElement).disabled).toBe(true);
|
||||
});
|
||||
|
||||
it("disables the button when isExecutingGraph is true", () => {
|
||||
mockUseRunGraph.mockReturnValue(
|
||||
createMockReturnValue({ isExecutingGraph: true }),
|
||||
);
|
||||
renderRunGraph();
|
||||
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it("disables the button when isTerminatingGraph is true", () => {
|
||||
mockUseRunGraph.mockReturnValue(
|
||||
createMockReturnValue({ isTerminatingGraph: true }),
|
||||
);
|
||||
renderRunGraph();
|
||||
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it("disables the button when isSaving is true", () => {
|
||||
mockUseRunGraph.mockReturnValue(createMockReturnValue({ isSaving: true }));
|
||||
renderRunGraph();
|
||||
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it("uses data-id run-graph-button when not running", () => {
|
||||
renderRunGraph();
|
||||
const button = screen.getByRole("button");
|
||||
expect(button.getAttribute("data-id")).toBe("run-graph-button");
|
||||
});
|
||||
|
||||
it("uses data-id stop-graph-button when running", () => {
|
||||
useGraphStore.setState({ isGraphRunning: true });
|
||||
renderRunGraph();
|
||||
const button = screen.getByRole("button");
|
||||
expect(button.getAttribute("data-id")).toBe("stop-graph-button");
|
||||
});
|
||||
|
||||
it("calls handleRunGraph when clicked and graph is not running", () => {
|
||||
const handleRunGraph = vi.fn();
|
||||
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleRunGraph }));
|
||||
renderRunGraph();
|
||||
fireEvent.click(screen.getByRole("button"));
|
||||
expect(handleRunGraph).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("calls handleStopGraph when clicked and graph is running", () => {
|
||||
const handleStopGraph = vi.fn();
|
||||
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleStopGraph }));
|
||||
useGraphStore.setState({ isGraphRunning: true });
|
||||
renderRunGraph();
|
||||
fireEvent.click(screen.getByRole("button"));
|
||||
expect(handleStopGraph).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("renders RunInputDialog hidden by default", () => {
|
||||
renderRunGraph();
|
||||
expect(screen.queryByTestId("run-input-dialog")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders RunInputDialog when openRunInputDialog is true", () => {
|
||||
mockUseRunGraph.mockReturnValue(
|
||||
createMockReturnValue({ openRunInputDialog: true }),
|
||||
);
|
||||
renderRunGraph();
|
||||
expect(screen.getByTestId("run-input-dialog")).toBeDefined();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,246 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useBlockMenuStore } from "../stores/blockMenuStore";
|
||||
import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types";
|
||||
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
|
||||
import { StoreAgent } from "@/app/api/__generated__/models/storeAgent";
|
||||
import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem";
|
||||
|
||||
beforeEach(() => {
|
||||
useBlockMenuStore.setState({
|
||||
searchQuery: "",
|
||||
searchId: undefined,
|
||||
defaultState: DefaultStateType.SUGGESTION,
|
||||
integration: undefined,
|
||||
filters: [],
|
||||
creators: [],
|
||||
creators_list: [],
|
||||
categoryCounts: {
|
||||
blocks: 0,
|
||||
integrations: 0,
|
||||
marketplace_agents: 0,
|
||||
my_agents: 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
describe("blockMenuStore", () => {
|
||||
describe("initial state", () => {
|
||||
it("has empty search and suggestion default state", () => {
|
||||
const state = useBlockMenuStore.getState();
|
||||
expect(state.searchQuery).toBe("");
|
||||
expect(state.searchId).toBeUndefined();
|
||||
expect(state.defaultState).toBe("suggestion");
|
||||
expect(state.integration).toBeUndefined();
|
||||
expect(state.filters).toEqual([]);
|
||||
expect(state.creators).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("search state", () => {
|
||||
it("sets search query", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("weather");
|
||||
expect(useBlockMenuStore.getState().searchQuery).toBe("weather");
|
||||
});
|
||||
|
||||
it("sets search id", () => {
|
||||
useBlockMenuStore.getState().setSearchId("abc-123");
|
||||
expect(useBlockMenuStore.getState().searchId).toBe("abc-123");
|
||||
});
|
||||
|
||||
it("clears search id", () => {
|
||||
useBlockMenuStore.getState().setSearchId("abc-123");
|
||||
useBlockMenuStore.getState().setSearchId(undefined);
|
||||
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("default state", () => {
|
||||
it("sets default state", () => {
|
||||
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
|
||||
expect(useBlockMenuStore.getState().defaultState).toBe(
|
||||
DefaultStateType.ALL_BLOCKS,
|
||||
);
|
||||
});
|
||||
|
||||
it("changes between states", () => {
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.setDefaultState(DefaultStateType.INTEGRATIONS);
|
||||
useBlockMenuStore.getState().setDefaultState(DefaultStateType.MY_AGENTS);
|
||||
expect(useBlockMenuStore.getState().defaultState).toBe(
|
||||
DefaultStateType.MY_AGENTS,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("integration", () => {
|
||||
it("sets integration", () => {
|
||||
useBlockMenuStore.getState().setIntegration("slack");
|
||||
expect(useBlockMenuStore.getState().integration).toBe("slack");
|
||||
});
|
||||
|
||||
it("clears integration", () => {
|
||||
useBlockMenuStore.getState().setIntegration("slack");
|
||||
useBlockMenuStore.getState().setIntegration(undefined);
|
||||
expect(useBlockMenuStore.getState().integration).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("filters", () => {
|
||||
it("adds a filter", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual(["blocks"]);
|
||||
});
|
||||
|
||||
it("adds multiple filters", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.addFilter(SearchEntryFilterAnyOfItem.integrations);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
"blocks",
|
||||
"integrations",
|
||||
]);
|
||||
});
|
||||
|
||||
it("removes a filter", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.addFilter(SearchEntryFilterAnyOfItem.integrations);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.removeFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual(["integrations"]);
|
||||
});
|
||||
|
||||
it("sets filters directly", () => {
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.setFilters([
|
||||
SearchEntryFilterAnyOfItem.my_agents,
|
||||
SearchEntryFilterAnyOfItem.marketplace_agents,
|
||||
]);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual([
|
||||
"my_agents",
|
||||
"marketplace_agents",
|
||||
]);
|
||||
});
|
||||
|
||||
it("removing a non-existent filter is a no-op", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore
|
||||
.getState()
|
||||
.removeFilter(SearchEntryFilterAnyOfItem.integrations);
|
||||
expect(useBlockMenuStore.getState().filters).toEqual(["blocks"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("creators", () => {
|
||||
it("adds a creator", () => {
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
|
||||
});
|
||||
|
||||
it("removes a creator", () => {
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
useBlockMenuStore.getState().addCreator("bob");
|
||||
useBlockMenuStore.getState().removeCreator("alice");
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["bob"]);
|
||||
});
|
||||
|
||||
it("sets creators directly", () => {
|
||||
useBlockMenuStore.getState().setCreators(["x", "y"]);
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["x", "y"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setCreatorsList", () => {
|
||||
it("extracts creators from store_agent items", () => {
|
||||
const items: SearchResponseItemsItem[] = [
|
||||
{
|
||||
slug: "agent-1",
|
||||
agent_name: "Agent 1",
|
||||
creator: "alice",
|
||||
} as StoreAgent,
|
||||
{
|
||||
slug: "agent-2",
|
||||
agent_name: "Agent 2",
|
||||
creator: "bob",
|
||||
} as StoreAgent,
|
||||
];
|
||||
|
||||
useBlockMenuStore.getState().setCreatorsList(items);
|
||||
const list = useBlockMenuStore.getState().creators_list;
|
||||
expect(list).toContain("alice");
|
||||
expect(list).toContain("bob");
|
||||
});
|
||||
|
||||
it("deduplicates creators across calls", () => {
|
||||
const items1 = [
|
||||
{
|
||||
slug: "a1",
|
||||
agent_name: "A1",
|
||||
creator: "alice",
|
||||
} as StoreAgent,
|
||||
] as SearchResponseItemsItem[];
|
||||
const items2 = [
|
||||
{
|
||||
slug: "a2",
|
||||
agent_name: "A2",
|
||||
creator: "alice",
|
||||
} as StoreAgent,
|
||||
] as SearchResponseItemsItem[];
|
||||
|
||||
useBlockMenuStore.getState().setCreatorsList(items1);
|
||||
useBlockMenuStore.getState().setCreatorsList(items2);
|
||||
|
||||
const aliceCount = useBlockMenuStore
|
||||
.getState()
|
||||
.creators_list.filter((c) => c === "alice").length;
|
||||
expect(aliceCount).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("categoryCounts", () => {
|
||||
it("sets category counts", () => {
|
||||
const counts = {
|
||||
blocks: 10,
|
||||
integrations: 5,
|
||||
marketplace_agents: 3,
|
||||
my_agents: 2,
|
||||
};
|
||||
useBlockMenuStore.getState().setCategoryCounts(counts);
|
||||
expect(useBlockMenuStore.getState().categoryCounts).toEqual(counts);
|
||||
});
|
||||
});
|
||||
|
||||
describe("reset", () => {
|
||||
it("resets search query, searchId, defaultState, and integration", () => {
|
||||
useBlockMenuStore.getState().setSearchQuery("test");
|
||||
useBlockMenuStore.getState().setSearchId("id-1");
|
||||
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
|
||||
useBlockMenuStore.getState().setIntegration("slack");
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
|
||||
useBlockMenuStore.getState().reset();
|
||||
|
||||
const state = useBlockMenuStore.getState();
|
||||
expect(state.searchQuery).toBe("");
|
||||
expect(state.searchId).toBeUndefined();
|
||||
expect(state.defaultState).toBe("suggestion");
|
||||
expect(state.integration).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does not clear filters or creators", () => {
|
||||
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
|
||||
useBlockMenuStore.getState().addCreator("alice");
|
||||
|
||||
useBlockMenuStore.getState().reset();
|
||||
|
||||
expect(useBlockMenuStore.getState().filters).toEqual(["blocks"]);
|
||||
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,104 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useControlPanelStore } from "../stores/controlPanelStore";
|
||||
|
||||
beforeEach(() => {
|
||||
useControlPanelStore.getState().reset();
|
||||
});
|
||||
|
||||
describe("controlPanelStore", () => {
|
||||
describe("initial state", () => {
|
||||
it("starts with all panels closed", () => {
|
||||
const state = useControlPanelStore.getState();
|
||||
expect(state.blockMenuOpen).toBe(false);
|
||||
expect(state.saveControlOpen).toBe(false);
|
||||
expect(state.forceOpenBlockMenu).toBe(false);
|
||||
expect(state.forceOpenSave).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setBlockMenuOpen", () => {
|
||||
it("opens the block menu", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(true);
|
||||
});
|
||||
|
||||
it("closes the block menu", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
useControlPanelStore.getState().setBlockMenuOpen(false);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setSaveControlOpen", () => {
|
||||
it("opens the save control", () => {
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
expect(useControlPanelStore.getState().saveControlOpen).toBe(true);
|
||||
});
|
||||
|
||||
it("closes the save control", () => {
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
useControlPanelStore.getState().setSaveControlOpen(false);
|
||||
expect(useControlPanelStore.getState().saveControlOpen).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setForceOpenBlockMenu", () => {
|
||||
it("sets force open state", () => {
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(true);
|
||||
expect(useControlPanelStore.getState().forceOpenBlockMenu).toBe(true);
|
||||
});
|
||||
|
||||
it("does not affect blockMenuOpen", () => {
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(true);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setForceOpenSave", () => {
|
||||
it("sets force open state", () => {
|
||||
useControlPanelStore.getState().setForceOpenSave(true);
|
||||
expect(useControlPanelStore.getState().forceOpenSave).toBe(true);
|
||||
});
|
||||
|
||||
it("does not affect saveControlOpen", () => {
|
||||
useControlPanelStore.getState().setForceOpenSave(true);
|
||||
expect(useControlPanelStore.getState().saveControlOpen).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("independent panel state", () => {
|
||||
it("opening block menu does not affect save control", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
expect(useControlPanelStore.getState().saveControlOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("opening save control does not affect block menu", () => {
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("both panels can be open simultaneously", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
expect(useControlPanelStore.getState().blockMenuOpen).toBe(true);
|
||||
expect(useControlPanelStore.getState().saveControlOpen).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("reset", () => {
|
||||
it("resets all state to defaults", () => {
|
||||
useControlPanelStore.getState().setBlockMenuOpen(true);
|
||||
useControlPanelStore.getState().setSaveControlOpen(true);
|
||||
useControlPanelStore.getState().setForceOpenBlockMenu(true);
|
||||
useControlPanelStore.getState().setForceOpenSave(true);
|
||||
|
||||
useControlPanelStore.getState().reset();
|
||||
|
||||
const state = useControlPanelStore.getState();
|
||||
expect(state.blockMenuOpen).toBe(false);
|
||||
expect(state.saveControlOpen).toBe(false);
|
||||
expect(state.forceOpenBlockMenu).toBe(false);
|
||||
expect(state.forceOpenSave).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,257 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { BlockUIType } from "../components/types";
|
||||
|
||||
vi.mock("@/services/storage/local-storage", () => {
|
||||
const store: Record<string, string> = {};
|
||||
return {
|
||||
Key: { COPIED_FLOW_DATA: "COPIED_FLOW_DATA" },
|
||||
storage: {
|
||||
get: (key: string) => store[key] ?? null,
|
||||
set: (key: string, value: string) => {
|
||||
store[key] = value;
|
||||
},
|
||||
clean: (key: string) => {
|
||||
delete store[key];
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
import { useCopyPasteStore } from "../stores/copyPasteStore";
|
||||
import { useNodeStore } from "../stores/nodeStore";
|
||||
import { useEdgeStore } from "../stores/edgeStore";
|
||||
import { useHistoryStore } from "../stores/historyStore";
|
||||
import { storage, Key } from "@/services/storage/local-storage";
|
||||
|
||||
function createTestNode(
|
||||
id: string,
|
||||
overrides: Partial<CustomNode> = {},
|
||||
): CustomNode {
|
||||
return {
|
||||
id,
|
||||
type: "custom",
|
||||
position: overrides.position ?? { x: 100, y: 200 },
|
||||
selected: overrides.selected,
|
||||
data: {
|
||||
hardcodedValues: {},
|
||||
title: `Node ${id}`,
|
||||
description: "test node",
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: BlockUIType.STANDARD,
|
||||
block_id: `block-${id}`,
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides.data,
|
||||
},
|
||||
} as CustomNode;
|
||||
}
|
||||
|
||||
describe("useCopyPasteStore", () => {
|
||||
beforeEach(() => {
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
useHistoryStore.getState().clear();
|
||||
storage.clean(Key.COPIED_FLOW_DATA);
|
||||
});
|
||||
|
||||
describe("copySelectedNodes", () => {
|
||||
it("copies a single selected node to localStorage", () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
|
||||
const stored = storage.get(Key.COPIED_FLOW_DATA);
|
||||
expect(stored).not.toBeNull();
|
||||
|
||||
const parsed = JSON.parse(stored!);
|
||||
expect(parsed.nodes).toHaveLength(1);
|
||||
expect(parsed.nodes[0].id).toBe("1");
|
||||
expect(parsed.edges).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("copies only edges between selected nodes", () => {
|
||||
const nodeA = createTestNode("a", { selected: true });
|
||||
const nodeB = createTestNode("b", { selected: true });
|
||||
const nodeC = createTestNode("c", { selected: false });
|
||||
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
|
||||
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
{
|
||||
id: "e-ab",
|
||||
source: "a",
|
||||
target: "b",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
},
|
||||
{
|
||||
id: "e-bc",
|
||||
source: "b",
|
||||
target: "c",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
},
|
||||
{
|
||||
id: "e-ac",
|
||||
source: "a",
|
||||
target: "c",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
|
||||
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
|
||||
expect(parsed.nodes).toHaveLength(2);
|
||||
expect(parsed.edges).toHaveLength(1);
|
||||
expect(parsed.edges[0].id).toBe("e-ab");
|
||||
});
|
||||
|
||||
it("stores empty data when no nodes are selected", () => {
|
||||
const node = createTestNode("1", { selected: false });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
|
||||
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
|
||||
expect(parsed.nodes).toHaveLength(0);
|
||||
expect(parsed.edges).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("pasteNodes", () => {
|
||||
it("creates new nodes with new IDs via incrementNodeCounter", () => {
|
||||
const node = createTestNode("orig", {
|
||||
selected: true,
|
||||
position: { x: 100, y: 200 },
|
||||
});
|
||||
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(2);
|
||||
|
||||
const pastedNode = nodes.find((n) => n.id !== "orig");
|
||||
expect(pastedNode).toBeDefined();
|
||||
expect(pastedNode!.id).not.toBe("orig");
|
||||
});
|
||||
|
||||
it("offsets pasted node positions by +50 x/y", () => {
|
||||
const node = createTestNode("orig", {
|
||||
selected: true,
|
||||
position: { x: 100, y: 200 },
|
||||
});
|
||||
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
const pastedNode = nodes.find((n) => n.id !== "orig");
|
||||
expect(pastedNode).toBeDefined();
|
||||
expect(pastedNode!.position).toEqual({ x: 150, y: 250 });
|
||||
});
|
||||
|
||||
it("preserves internal connections with remapped IDs", () => {
|
||||
const nodeA = createTestNode("a", {
|
||||
selected: true,
|
||||
position: { x: 0, y: 0 },
|
||||
});
|
||||
const nodeB = createTestNode("b", {
|
||||
selected: true,
|
||||
position: { x: 200, y: 0 },
|
||||
});
|
||||
useNodeStore.setState({ nodes: [nodeA, nodeB], nodeCounter: 0 });
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
{
|
||||
id: "e-ab",
|
||||
source: "a",
|
||||
target: "b",
|
||||
sourceHandle: "output",
|
||||
targetHandle: "input",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
|
||||
const { edges } = useEdgeStore.getState();
|
||||
const newEdges = edges.filter((e) => e.id !== "e-ab");
|
||||
expect(newEdges).toHaveLength(1);
|
||||
|
||||
const newEdge = newEdges[0];
|
||||
expect(newEdge.source).not.toBe("a");
|
||||
expect(newEdge.target).not.toBe("b");
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
const pastedNodeIDs = nodes
|
||||
.filter((n) => n.id !== "a" && n.id !== "b")
|
||||
.map((n) => n.id);
|
||||
|
||||
expect(pastedNodeIDs).toContain(newEdge.source);
|
||||
expect(pastedNodeIDs).toContain(newEdge.target);
|
||||
});
|
||||
|
||||
it("deselects existing nodes and selects pasted ones", () => {
|
||||
const existingNode = createTestNode("existing", {
|
||||
selected: true,
|
||||
position: { x: 0, y: 0 },
|
||||
});
|
||||
const nodeToCopy = createTestNode("copy-me", {
|
||||
selected: true,
|
||||
position: { x: 100, y: 100 },
|
||||
});
|
||||
useNodeStore.setState({
|
||||
nodes: [existingNode, nodeToCopy],
|
||||
nodeCounter: 0,
|
||||
});
|
||||
|
||||
useCopyPasteStore.getState().copySelectedNodes();
|
||||
|
||||
// Deselect nodeToCopy, keep existingNode selected to verify deselection on paste
|
||||
useNodeStore.setState({
|
||||
nodes: [
|
||||
{ ...existingNode, selected: true },
|
||||
{ ...nodeToCopy, selected: false },
|
||||
],
|
||||
});
|
||||
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
const originalNodes = nodes.filter(
|
||||
(n) => n.id === "existing" || n.id === "copy-me",
|
||||
);
|
||||
const pastedNodes = nodes.filter(
|
||||
(n) => n.id !== "existing" && n.id !== "copy-me",
|
||||
);
|
||||
|
||||
originalNodes.forEach((n) => {
|
||||
expect(n.selected).toBe(false);
|
||||
});
|
||||
pastedNodes.forEach((n) => {
|
||||
expect(n.selected).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it("does nothing when clipboard is empty", () => {
|
||||
const node = createTestNode("1", { position: { x: 0, y: 0 } });
|
||||
useNodeStore.setState({ nodes: [node], nodeCounter: 0 });
|
||||
|
||||
useCopyPasteStore.getState().pasteNodes();
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
expect(nodes[0].id).toBe("1");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,751 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { MarkerType } from "@xyflow/react";
|
||||
import { useEdgeStore } from "../stores/edgeStore";
|
||||
import { useNodeStore } from "../stores/nodeStore";
|
||||
import { useHistoryStore } from "../stores/historyStore";
|
||||
import type { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
|
||||
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
import type { Link } from "@/app/api/__generated__/models/link";
|
||||
|
||||
function makeEdge(overrides: Partial<CustomEdge> & { id: string }): CustomEdge {
|
||||
return {
|
||||
type: "custom",
|
||||
source: "node-a",
|
||||
target: "node-b",
|
||||
sourceHandle: "output",
|
||||
targetHandle: "input",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeExecutionResult(
|
||||
overrides: Partial<NodeExecutionResult>,
|
||||
): NodeExecutionResult {
|
||||
return {
|
||||
user_id: "user-1",
|
||||
graph_id: "graph-1",
|
||||
graph_version: 1,
|
||||
graph_exec_id: "gexec-1",
|
||||
node_exec_id: "nexec-1",
|
||||
node_id: "node-1",
|
||||
block_id: "block-1",
|
||||
status: "INCOMPLETE",
|
||||
input_data: {},
|
||||
output_data: {},
|
||||
add_time: new Date(),
|
||||
queue_time: null,
|
||||
start_time: null,
|
||||
end_time: null,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
useNodeStore.setState({ nodes: [] });
|
||||
useHistoryStore.setState({ past: [], future: [] });
|
||||
});
|
||||
|
||||
describe("edgeStore", () => {
|
||||
describe("setEdges", () => {
|
||||
it("replaces all edges", () => {
|
||||
const edges = [
|
||||
makeEdge({ id: "e1" }),
|
||||
makeEdge({ id: "e2", source: "node-c" }),
|
||||
];
|
||||
|
||||
useEdgeStore.getState().setEdges(edges);
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(2);
|
||||
expect(useEdgeStore.getState().edges[0].id).toBe("e1");
|
||||
expect(useEdgeStore.getState().edges[1].id).toBe("e2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("addEdge", () => {
|
||||
it("adds an edge and auto-generates an ID", () => {
|
||||
const result = useEdgeStore.getState().addEdge({
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
expect(result.id).toBe("n1:out->n2:in");
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
expect(useEdgeStore.getState().edges[0].id).toBe("n1:out->n2:in");
|
||||
});
|
||||
|
||||
it("uses provided ID when given", () => {
|
||||
const result = useEdgeStore.getState().addEdge({
|
||||
id: "custom-id",
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
expect(result.id).toBe("custom-id");
|
||||
});
|
||||
|
||||
it("sets type to custom and adds arrow marker", () => {
|
||||
const result = useEdgeStore.getState().addEdge({
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
expect(result.type).toBe("custom");
|
||||
expect(result.markerEnd).toEqual({
|
||||
type: MarkerType.ArrowClosed,
|
||||
strokeWidth: 2,
|
||||
color: "#555",
|
||||
});
|
||||
});
|
||||
|
||||
it("rejects duplicate edges without adding", () => {
|
||||
useEdgeStore.getState().addEdge({
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
|
||||
|
||||
const duplicate = useEdgeStore.getState().addEdge({
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
expect(duplicate.id).toBe("n1:out->n2:in");
|
||||
expect(pushSpy).not.toHaveBeenCalled();
|
||||
|
||||
pushSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("pushes previous state to history store", () => {
|
||||
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
|
||||
|
||||
useEdgeStore.getState().addEdge({
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
});
|
||||
|
||||
expect(pushSpy).toHaveBeenCalledWith({
|
||||
nodes: [],
|
||||
edges: [],
|
||||
});
|
||||
|
||||
pushSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("removeEdge", () => {
|
||||
it("removes an edge by ID", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [makeEdge({ id: "e1" }), makeEdge({ id: "e2" })],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().removeEdge("e1");
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
expect(useEdgeStore.getState().edges[0].id).toBe("e2");
|
||||
});
|
||||
|
||||
it("does nothing when removing a non-existent edge", () => {
|
||||
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
|
||||
|
||||
useEdgeStore.getState().removeEdge("nonexistent");
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("pushes previous state to history store", () => {
|
||||
const existingEdges = [makeEdge({ id: "e1" })];
|
||||
useEdgeStore.setState({ edges: existingEdges });
|
||||
|
||||
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
|
||||
|
||||
useEdgeStore.getState().removeEdge("e1");
|
||||
|
||||
expect(pushSpy).toHaveBeenCalledWith({
|
||||
nodes: [],
|
||||
edges: existingEdges,
|
||||
});
|
||||
|
||||
pushSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("upsertMany", () => {
|
||||
it("inserts new edges", () => {
|
||||
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
|
||||
|
||||
useEdgeStore.getState().upsertMany([makeEdge({ id: "e2" })]);
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("updates existing edges by ID", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [makeEdge({ id: "e1", source: "old-source" })],
|
||||
});
|
||||
|
||||
useEdgeStore
|
||||
.getState()
|
||||
.upsertMany([makeEdge({ id: "e1", source: "new-source" })]);
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
expect(useEdgeStore.getState().edges[0].source).toBe("new-source");
|
||||
});
|
||||
|
||||
it("handles mixed inserts and updates", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [makeEdge({ id: "e1", source: "old" })],
|
||||
});
|
||||
|
||||
useEdgeStore
|
||||
.getState()
|
||||
.upsertMany([
|
||||
makeEdge({ id: "e1", source: "updated" }),
|
||||
makeEdge({ id: "e2", source: "new" }),
|
||||
]);
|
||||
|
||||
const edges = useEdgeStore.getState().edges;
|
||||
expect(edges).toHaveLength(2);
|
||||
expect(edges.find((e) => e.id === "e1")?.source).toBe("updated");
|
||||
expect(edges.find((e) => e.id === "e2")?.source).toBe("new");
|
||||
});
|
||||
});
|
||||
|
||||
describe("removeEdgesByHandlePrefix", () => {
|
||||
it("removes edges targeting a node with matching handle prefix", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_foo" }),
|
||||
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_bar" }),
|
||||
makeEdge({
|
||||
id: "e3",
|
||||
target: "node-b",
|
||||
targetHandle: "other_handle",
|
||||
}),
|
||||
makeEdge({ id: "e4", target: "node-c", targetHandle: "input_foo" }),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
|
||||
|
||||
const edges = useEdgeStore.getState().edges;
|
||||
expect(edges).toHaveLength(2);
|
||||
expect(edges.map((e) => e.id).sort()).toEqual(["e3", "e4"]);
|
||||
});
|
||||
|
||||
it("does not remove edges where target does not match nodeId", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
source: "node-b",
|
||||
target: "node-c",
|
||||
targetHandle: "input_x",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getNodeEdges", () => {
|
||||
it("returns edges where node is source", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
|
||||
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
|
||||
],
|
||||
});
|
||||
|
||||
const result = useEdgeStore.getState().getNodeEdges("node-a");
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe("e1");
|
||||
});
|
||||
|
||||
it("returns edges where node is target", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
|
||||
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
|
||||
],
|
||||
});
|
||||
|
||||
const result = useEdgeStore.getState().getNodeEdges("node-b");
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe("e1");
|
||||
});
|
||||
|
||||
it("returns edges for both source and target", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
|
||||
makeEdge({ id: "e2", source: "node-b", target: "node-c" }),
|
||||
makeEdge({ id: "e3", source: "node-d", target: "node-e" }),
|
||||
],
|
||||
});
|
||||
|
||||
const result = useEdgeStore.getState().getNodeEdges("node-b");
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result.map((e) => e.id).sort()).toEqual(["e1", "e2"]);
|
||||
});
|
||||
|
||||
it("returns empty array for unconnected node", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [makeEdge({ id: "e1", source: "node-a", target: "node-b" })],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().getNodeEdges("node-z")).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isInputConnected", () => {
|
||||
it("returns true when target handle is connected", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().isInputConnected("node-b", "input")).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it("returns false when target handle is not connected", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().isInputConnected("node-b", "other")).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
it("returns false when node is source not target", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
source: "node-b",
|
||||
target: "node-c",
|
||||
sourceHandle: "output",
|
||||
targetHandle: "input",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().isInputConnected("node-b", "output")).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isOutputConnected", () => {
|
||||
it("returns true when source handle is connected", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
source: "node-a",
|
||||
sourceHandle: "output",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(
|
||||
useEdgeStore.getState().isOutputConnected("node-a", "output"),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when source handle is not connected", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
source: "node-a",
|
||||
sourceHandle: "output",
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().isOutputConnected("node-a", "other")).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getBackendLinks", () => {
|
||||
it("converts edges to Link format", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
source: "n1",
|
||||
target: "n2",
|
||||
sourceHandle: "out",
|
||||
targetHandle: "in",
|
||||
data: { isStatic: true },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
const links = useEdgeStore.getState().getBackendLinks();
|
||||
|
||||
expect(links).toHaveLength(1);
|
||||
expect(links[0]).toEqual({
|
||||
id: "e1",
|
||||
source_id: "n1",
|
||||
sink_id: "n2",
|
||||
source_name: "out",
|
||||
sink_name: "in",
|
||||
is_static: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("addLinks", () => {
|
||||
it("converts Links to edges and adds them", () => {
|
||||
const links: Link[] = [
|
||||
{
|
||||
id: "link-1",
|
||||
source_id: "n1",
|
||||
sink_id: "n2",
|
||||
source_name: "out",
|
||||
sink_name: "in",
|
||||
is_static: false,
|
||||
},
|
||||
];
|
||||
|
||||
useEdgeStore.getState().addLinks(links);
|
||||
|
||||
const edges = useEdgeStore.getState().edges;
|
||||
expect(edges).toHaveLength(1);
|
||||
expect(edges[0].source).toBe("n1");
|
||||
expect(edges[0].target).toBe("n2");
|
||||
expect(edges[0].sourceHandle).toBe("out");
|
||||
expect(edges[0].targetHandle).toBe("in");
|
||||
expect(edges[0].data?.isStatic).toBe(false);
|
||||
});
|
||||
|
||||
it("adds multiple links", () => {
|
||||
const links: Link[] = [
|
||||
{
|
||||
id: "link-1",
|
||||
source_id: "n1",
|
||||
sink_id: "n2",
|
||||
source_name: "out",
|
||||
sink_name: "in",
|
||||
},
|
||||
{
|
||||
id: "link-2",
|
||||
source_id: "n3",
|
||||
sink_id: "n4",
|
||||
source_name: "result",
|
||||
sink_name: "value",
|
||||
},
|
||||
];
|
||||
|
||||
useEdgeStore.getState().addLinks(links);
|
||||
|
||||
expect(useEdgeStore.getState().edges).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getAllHandleIdsOfANode", () => {
|
||||
it("returns targetHandle values for edges targeting the node", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_a" }),
|
||||
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_b" }),
|
||||
makeEdge({ id: "e3", target: "node-c", targetHandle: "input_c" }),
|
||||
],
|
||||
});
|
||||
|
||||
const handles = useEdgeStore.getState().getAllHandleIdsOfANode("node-b");
|
||||
expect(handles).toEqual(["input_a", "input_b"]);
|
||||
});
|
||||
|
||||
it("returns empty array when no edges target the node", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [makeEdge({ id: "e1", source: "node-b", target: "node-c" })],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual(
|
||||
[],
|
||||
);
|
||||
});
|
||||
|
||||
it("returns empty string for edges with no targetHandle", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: undefined,
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual([
|
||||
"",
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("updateEdgeBeads", () => {
|
||||
it("updates bead counts for edges targeting the node", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "COMPLETED",
|
||||
input_data: { input: "some-value" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(1);
|
||||
expect(edge.data?.beadDown).toBe(1);
|
||||
});
|
||||
|
||||
it("counts INCOMPLETE status in beadUp but not beadDown", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "INCOMPLETE",
|
||||
input_data: { input: "data" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(1);
|
||||
expect(edge.data?.beadDown).toBe(0);
|
||||
});
|
||||
|
||||
it("does not modify edges not targeting the node", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-c",
|
||||
targetHandle: "input",
|
||||
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "COMPLETED",
|
||||
input_data: { input: "data" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(0);
|
||||
expect(edge.data?.beadDown).toBe(0);
|
||||
});
|
||||
|
||||
it("does not update edge when input_data has no matching handle", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "COMPLETED",
|
||||
input_data: { other_handle: "data" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(0);
|
||||
expect(edge.data?.beadDown).toBe(0);
|
||||
});
|
||||
|
||||
it("accumulates beads across multiple executions", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "COMPLETED",
|
||||
input_data: { input: "data1" },
|
||||
}),
|
||||
);
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-2",
|
||||
status: "INCOMPLETE",
|
||||
input_data: { input: "data2" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(2);
|
||||
expect(edge.data?.beadDown).toBe(1);
|
||||
});
|
||||
|
||||
it("handles static edges by setting beadUp to beadDown + 1", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
target: "node-b",
|
||||
targetHandle: "input",
|
||||
data: {
|
||||
isStatic: true,
|
||||
beadUp: 0,
|
||||
beadDown: 0,
|
||||
beadData: new Map(),
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().updateEdgeBeads(
|
||||
"node-b",
|
||||
makeExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
status: "COMPLETED",
|
||||
input_data: { input: "data" },
|
||||
}),
|
||||
);
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.beadUp).toBe(2);
|
||||
expect(edge.data?.beadDown).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("resetEdgeBeads", () => {
|
||||
it("resets all bead data on all edges", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
data: {
|
||||
beadUp: 5,
|
||||
beadDown: 3,
|
||||
beadData: new Map([["exec-1", "COMPLETED"]]),
|
||||
},
|
||||
}),
|
||||
makeEdge({
|
||||
id: "e2",
|
||||
data: {
|
||||
beadUp: 2,
|
||||
beadDown: 1,
|
||||
beadData: new Map([["exec-2", "INCOMPLETE"]]),
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().resetEdgeBeads();
|
||||
|
||||
const edges = useEdgeStore.getState().edges;
|
||||
for (const edge of edges) {
|
||||
expect(edge.data?.beadUp).toBe(0);
|
||||
expect(edge.data?.beadDown).toBe(0);
|
||||
expect(edge.data?.beadData?.size).toBe(0);
|
||||
}
|
||||
});
|
||||
|
||||
it("preserves other edge data when resetting beads", () => {
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
makeEdge({
|
||||
id: "e1",
|
||||
data: {
|
||||
isStatic: true,
|
||||
edgeColorClass: "text-red-500",
|
||||
beadUp: 3,
|
||||
beadDown: 2,
|
||||
beadData: new Map(),
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
useEdgeStore.getState().resetEdgeBeads();
|
||||
|
||||
const edge = useEdgeStore.getState().edges[0];
|
||||
expect(edge.data?.isStatic).toBe(true);
|
||||
expect(edge.data?.edgeColorClass).toBe("text-red-500");
|
||||
expect(edge.data?.beadUp).toBe(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,347 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useGraphStore } from "../stores/graphStore";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||
|
||||
function createTestGraphMeta(
|
||||
overrides: Partial<GraphMeta> & { id: string; name: string },
|
||||
): GraphMeta {
|
||||
return {
|
||||
version: 1,
|
||||
description: "",
|
||||
is_active: true,
|
||||
user_id: "test-user",
|
||||
created_at: new Date("2024-01-01T00:00:00Z"),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function resetStore() {
|
||||
useGraphStore.setState({
|
||||
graphExecutionStatus: undefined,
|
||||
isGraphRunning: false,
|
||||
inputSchema: null,
|
||||
credentialsInputSchema: null,
|
||||
outputSchema: null,
|
||||
availableSubGraphs: [],
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
});
|
||||
|
||||
describe("graphStore", () => {
|
||||
describe("execution status transitions", () => {
|
||||
it("handles QUEUED -> RUNNING -> COMPLETED transition", () => {
|
||||
const { setGraphExecutionStatus } = useGraphStore.getState();
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
|
||||
expect(useGraphStore.getState().graphExecutionStatus).toBe(
|
||||
AgentExecutionStatus.QUEUED,
|
||||
);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
expect(useGraphStore.getState().graphExecutionStatus).toBe(
|
||||
AgentExecutionStatus.RUNNING,
|
||||
);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
|
||||
expect(useGraphStore.getState().graphExecutionStatus).toBe(
|
||||
AgentExecutionStatus.COMPLETED,
|
||||
);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
|
||||
it("handles QUEUED -> RUNNING -> FAILED transition", () => {
|
||||
const { setGraphExecutionStatus } = useGraphStore.getState();
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
setGraphExecutionStatus(AgentExecutionStatus.FAILED);
|
||||
expect(useGraphStore.getState().graphExecutionStatus).toBe(
|
||||
AgentExecutionStatus.FAILED,
|
||||
);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setGraphExecutionStatus auto-sets isGraphRunning", () => {
|
||||
it("sets isGraphRunning to true for RUNNING", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to true for QUEUED", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to false for COMPLETED", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to false for FAILED", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.FAILED);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to false for TERMINATED", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.TERMINATED);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to false for INCOMPLETE", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.INCOMPLETE);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
|
||||
it("sets isGraphRunning to false for undefined", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
useGraphStore.getState().setGraphExecutionStatus(undefined);
|
||||
expect(useGraphStore.getState().graphExecutionStatus).toBeUndefined();
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setIsGraphRunning", () => {
|
||||
it("sets isGraphRunning independently of status", () => {
|
||||
useGraphStore.getState().setIsGraphRunning(true);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(true);
|
||||
|
||||
useGraphStore.getState().setIsGraphRunning(false);
|
||||
expect(useGraphStore.getState().isGraphRunning).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("schema management", () => {
|
||||
it("sets all three schemas via setGraphSchemas", () => {
|
||||
const input = { properties: { prompt: { type: "string" } } };
|
||||
const credentials = { properties: { apiKey: { type: "string" } } };
|
||||
const output = { properties: { result: { type: "string" } } };
|
||||
|
||||
useGraphStore.getState().setGraphSchemas(input, credentials, output);
|
||||
|
||||
const state = useGraphStore.getState();
|
||||
expect(state.inputSchema).toEqual(input);
|
||||
expect(state.credentialsInputSchema).toEqual(credentials);
|
||||
expect(state.outputSchema).toEqual(output);
|
||||
});
|
||||
|
||||
it("sets schemas to null", () => {
|
||||
const input = { properties: { prompt: { type: "string" } } };
|
||||
useGraphStore.getState().setGraphSchemas(input, null, null);
|
||||
|
||||
const state = useGraphStore.getState();
|
||||
expect(state.inputSchema).toEqual(input);
|
||||
expect(state.credentialsInputSchema).toBeNull();
|
||||
expect(state.outputSchema).toBeNull();
|
||||
});
|
||||
|
||||
it("overwrites previous schemas", () => {
|
||||
const first = { properties: { a: { type: "string" } } };
|
||||
const second = { properties: { b: { type: "number" } } };
|
||||
|
||||
useGraphStore.getState().setGraphSchemas(first, first, first);
|
||||
useGraphStore.getState().setGraphSchemas(second, null, second);
|
||||
|
||||
const state = useGraphStore.getState();
|
||||
expect(state.inputSchema).toEqual(second);
|
||||
expect(state.credentialsInputSchema).toBeNull();
|
||||
expect(state.outputSchema).toEqual(second);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasInputs", () => {
|
||||
it("returns false when inputSchema is null", () => {
|
||||
expect(useGraphStore.getState().hasInputs()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when inputSchema has no properties", () => {
|
||||
useGraphStore.getState().setGraphSchemas({}, null, null);
|
||||
expect(useGraphStore.getState().hasInputs()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when inputSchema has empty properties", () => {
|
||||
useGraphStore.getState().setGraphSchemas({ properties: {} }, null, null);
|
||||
expect(useGraphStore.getState().hasInputs()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when inputSchema has properties", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphSchemas(
|
||||
{ properties: { prompt: { type: "string" } } },
|
||||
null,
|
||||
null,
|
||||
);
|
||||
expect(useGraphStore.getState().hasInputs()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasCredentials", () => {
|
||||
it("returns false when credentialsInputSchema is null", () => {
|
||||
expect(useGraphStore.getState().hasCredentials()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when credentialsInputSchema has empty properties", () => {
|
||||
useGraphStore.getState().setGraphSchemas(null, { properties: {} }, null);
|
||||
expect(useGraphStore.getState().hasCredentials()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when credentialsInputSchema has properties", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphSchemas(
|
||||
null,
|
||||
{ properties: { apiKey: { type: "string" } } },
|
||||
null,
|
||||
);
|
||||
expect(useGraphStore.getState().hasCredentials()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasOutputs", () => {
|
||||
it("returns false when outputSchema is null", () => {
|
||||
expect(useGraphStore.getState().hasOutputs()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when outputSchema has empty properties", () => {
|
||||
useGraphStore.getState().setGraphSchemas(null, null, { properties: {} });
|
||||
expect(useGraphStore.getState().hasOutputs()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when outputSchema has properties", () => {
|
||||
useGraphStore.getState().setGraphSchemas(null, null, {
|
||||
properties: { result: { type: "string" } },
|
||||
});
|
||||
expect(useGraphStore.getState().hasOutputs()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("reset", () => {
|
||||
it("clears execution status and schemas but preserves outputSchema and availableSubGraphs", () => {
|
||||
const subGraphs: GraphMeta[] = [
|
||||
createTestGraphMeta({
|
||||
id: "sub-1",
|
||||
name: "Sub Graph",
|
||||
description: "A sub graph",
|
||||
}),
|
||||
];
|
||||
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setGraphSchemas(
|
||||
{ properties: { a: {} } },
|
||||
{ properties: { b: {} } },
|
||||
{ properties: { c: {} } },
|
||||
);
|
||||
useGraphStore.getState().setAvailableSubGraphs(subGraphs);
|
||||
|
||||
useGraphStore.getState().reset();
|
||||
|
||||
const state = useGraphStore.getState();
|
||||
expect(state.graphExecutionStatus).toBeUndefined();
|
||||
expect(state.isGraphRunning).toBe(false);
|
||||
expect(state.inputSchema).toBeNull();
|
||||
expect(state.credentialsInputSchema).toBeNull();
|
||||
// reset does not clear outputSchema or availableSubGraphs
|
||||
expect(state.outputSchema).toEqual({ properties: { c: {} } });
|
||||
expect(state.availableSubGraphs).toEqual(subGraphs);
|
||||
});
|
||||
|
||||
it("is idempotent on fresh state", () => {
|
||||
useGraphStore.getState().reset();
|
||||
|
||||
const state = useGraphStore.getState();
|
||||
expect(state.graphExecutionStatus).toBeUndefined();
|
||||
expect(state.isGraphRunning).toBe(false);
|
||||
expect(state.inputSchema).toBeNull();
|
||||
expect(state.credentialsInputSchema).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("setAvailableSubGraphs", () => {
|
||||
it("sets sub-graphs list", () => {
|
||||
const graphs: GraphMeta[] = [
|
||||
createTestGraphMeta({
|
||||
id: "graph-1",
|
||||
name: "Graph One",
|
||||
description: "First graph",
|
||||
}),
|
||||
createTestGraphMeta({
|
||||
id: "graph-2",
|
||||
version: 2,
|
||||
name: "Graph Two",
|
||||
description: "Second graph",
|
||||
}),
|
||||
];
|
||||
|
||||
useGraphStore.getState().setAvailableSubGraphs(graphs);
|
||||
expect(useGraphStore.getState().availableSubGraphs).toEqual(graphs);
|
||||
});
|
||||
|
||||
it("replaces previous sub-graphs", () => {
|
||||
const first: GraphMeta[] = [createTestGraphMeta({ id: "a", name: "A" })];
|
||||
const second: GraphMeta[] = [
|
||||
createTestGraphMeta({ id: "b", name: "B" }),
|
||||
createTestGraphMeta({ id: "c", name: "C" }),
|
||||
];
|
||||
|
||||
useGraphStore.getState().setAvailableSubGraphs(first);
|
||||
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(1);
|
||||
|
||||
useGraphStore.getState().setAvailableSubGraphs(second);
|
||||
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(2);
|
||||
expect(useGraphStore.getState().availableSubGraphs).toEqual(second);
|
||||
});
|
||||
|
||||
it("can set empty sub-graphs list", () => {
|
||||
useGraphStore
|
||||
.getState()
|
||||
.setAvailableSubGraphs([createTestGraphMeta({ id: "x", name: "X" })]);
|
||||
useGraphStore.getState().setAvailableSubGraphs([]);
|
||||
expect(useGraphStore.getState().availableSubGraphs).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,407 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useHistoryStore } from "../stores/historyStore";
|
||||
import { useNodeStore } from "../stores/nodeStore";
|
||||
import { useEdgeStore } from "../stores/edgeStore";
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
|
||||
|
||||
function createTestNode(
|
||||
id: string,
|
||||
overrides: Partial<CustomNode> = {},
|
||||
): CustomNode {
|
||||
return {
|
||||
id,
|
||||
type: "custom" as const,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
hardcodedValues: {},
|
||||
title: `Node ${id}`,
|
||||
description: "",
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: "STANDARD" as never,
|
||||
block_id: `block-${id}`,
|
||||
costs: [],
|
||||
categories: [],
|
||||
},
|
||||
...overrides,
|
||||
} as CustomNode;
|
||||
}
|
||||
|
||||
function createTestEdge(
|
||||
id: string,
|
||||
source: string,
|
||||
target: string,
|
||||
): CustomEdge {
|
||||
return {
|
||||
id,
|
||||
source,
|
||||
target,
|
||||
type: "custom" as const,
|
||||
} as CustomEdge;
|
||||
}
|
||||
|
||||
async function flushMicrotasks() {
|
||||
await new Promise<void>((resolve) => queueMicrotask(resolve));
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
useHistoryStore.getState().clear();
|
||||
useNodeStore.setState({ nodes: [] });
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
});
|
||||
|
||||
describe("historyStore", () => {
|
||||
describe("undo/redo single action", () => {
|
||||
it("undoes a single pushed state", async () => {
|
||||
const node = createTestNode("1");
|
||||
|
||||
// Initialize history with node present as baseline
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().initializeHistory();
|
||||
|
||||
// Simulate a change: clear nodes
|
||||
useNodeStore.setState({ nodes: [] });
|
||||
|
||||
// Undo should restore to [node]
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node]);
|
||||
expect(useHistoryStore.getState().future).toHaveLength(1);
|
||||
expect(useHistoryStore.getState().future[0].nodes).toEqual([]);
|
||||
});
|
||||
|
||||
it("redoes after undo", async () => {
|
||||
const node = createTestNode("1");
|
||||
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().initializeHistory();
|
||||
|
||||
// Change: clear nodes
|
||||
useNodeStore.setState({ nodes: [] });
|
||||
|
||||
// Undo → back to [node]
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node]);
|
||||
|
||||
// Redo → back to []
|
||||
useHistoryStore.getState().redo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("undo/redo multiple actions", () => {
|
||||
it("undoes through multiple states in order", async () => {
|
||||
const node1 = createTestNode("1");
|
||||
const node2 = createTestNode("2");
|
||||
const node3 = createTestNode("3");
|
||||
|
||||
// Initialize with [node1] as baseline
|
||||
useNodeStore.setState({ nodes: [node1] });
|
||||
useHistoryStore.getState().initializeHistory();
|
||||
|
||||
// Second change: add node2, push pre-change state
|
||||
useNodeStore.setState({ nodes: [node1, node2] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
// Third change: add node3, push pre-change state
|
||||
useNodeStore.setState({ nodes: [node1, node2, node3] });
|
||||
useHistoryStore
|
||||
.getState()
|
||||
.pushState({ nodes: [node1, node2], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
// Undo 1: back to [node1, node2]
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
|
||||
|
||||
// Undo 2: back to [node1]
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("undo past empty history", () => {
|
||||
it("does nothing when there is no history to undo", () => {
|
||||
useHistoryStore.getState().undo();
|
||||
|
||||
expect(useNodeStore.getState().nodes).toEqual([]);
|
||||
expect(useEdgeStore.getState().edges).toEqual([]);
|
||||
expect(useHistoryStore.getState().past).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("does nothing when current state equals last past entry", () => {
|
||||
expect(useHistoryStore.getState().canUndo()).toBe(false);
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
|
||||
expect(useHistoryStore.getState().past).toHaveLength(1);
|
||||
expect(useHistoryStore.getState().future).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("state consistency: undo after node add restores previous, redo restores added", () => {
|
||||
it("undo removes added node, redo restores it", async () => {
|
||||
const node = createTestNode("added");
|
||||
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([]);
|
||||
|
||||
useHistoryStore.getState().redo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("history limits", () => {
|
||||
it("does not grow past MAX_HISTORY (50)", async () => {
|
||||
for (let i = 0; i < 60; i++) {
|
||||
const node = createTestNode(`node-${i}`);
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({
|
||||
nodes: [createTestNode(`node-${i - 1}`)],
|
||||
edges: [],
|
||||
});
|
||||
await flushMicrotasks();
|
||||
}
|
||||
|
||||
expect(useHistoryStore.getState().past.length).toBeLessThanOrEqual(50);
|
||||
});
|
||||
});
|
||||
|
||||
describe("edge cases", () => {
|
||||
it("redo does nothing when future is empty", () => {
|
||||
const nodesBefore = useNodeStore.getState().nodes;
|
||||
const edgesBefore = useEdgeStore.getState().edges;
|
||||
|
||||
useHistoryStore.getState().redo();
|
||||
|
||||
expect(useNodeStore.getState().nodes).toEqual(nodesBefore);
|
||||
expect(useEdgeStore.getState().edges).toEqual(edgesBefore);
|
||||
});
|
||||
|
||||
it("interleaved undo/redo sequence", async () => {
|
||||
const node1 = createTestNode("1");
|
||||
const node2 = createTestNode("2");
|
||||
const node3 = createTestNode("3");
|
||||
|
||||
useNodeStore.setState({ nodes: [node1] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useNodeStore.setState({ nodes: [node1, node2] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useNodeStore.setState({ nodes: [node1, node2, node3] });
|
||||
useHistoryStore.getState().pushState({
|
||||
nodes: [node1, node2],
|
||||
edges: [],
|
||||
});
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1]);
|
||||
|
||||
useHistoryStore.getState().redo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1]);
|
||||
|
||||
useHistoryStore.getState().redo();
|
||||
useHistoryStore.getState().redo();
|
||||
expect(useNodeStore.getState().nodes).toEqual([node1, node2, node3]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("canUndo / canRedo", () => {
|
||||
it("canUndo is false on fresh store", () => {
|
||||
expect(useHistoryStore.getState().canUndo()).toBe(false);
|
||||
});
|
||||
|
||||
it("canUndo is true when current state differs from last past entry", async () => {
|
||||
const node = createTestNode("1");
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
expect(useHistoryStore.getState().canUndo()).toBe(true);
|
||||
});
|
||||
|
||||
it("canRedo is false on fresh store", () => {
|
||||
expect(useHistoryStore.getState().canRedo()).toBe(false);
|
||||
});
|
||||
|
||||
it("canRedo is true after undo", async () => {
|
||||
const node = createTestNode("1");
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
|
||||
expect(useHistoryStore.getState().canRedo()).toBe(true);
|
||||
});
|
||||
|
||||
it("canRedo becomes false after redo exhausts future", async () => {
|
||||
const node = createTestNode("1");
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
useHistoryStore.getState().redo();
|
||||
|
||||
expect(useHistoryStore.getState().canRedo()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("pushState deduplication", () => {
|
||||
it("does not push a state identical to the last past entry", async () => {
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
expect(useHistoryStore.getState().past).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("does not push if state matches current node/edge store state", async () => {
|
||||
const node = createTestNode("1");
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
|
||||
useHistoryStore.getState().pushState({ nodes: [node], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
expect(useHistoryStore.getState().past).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("initializeHistory", () => {
|
||||
it("resets history with current node/edge store state", async () => {
|
||||
const node = createTestNode("1");
|
||||
const edge = createTestEdge("e1", "1", "2");
|
||||
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useEdgeStore.setState({ edges: [edge] });
|
||||
|
||||
useNodeStore.setState({ nodes: [node, createTestNode("2")] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node], edges: [edge] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().initializeHistory();
|
||||
|
||||
const { past, future } = useHistoryStore.getState();
|
||||
expect(past).toHaveLength(1);
|
||||
expect(past[0].nodes).toEqual(useNodeStore.getState().nodes);
|
||||
expect(past[0].edges).toEqual(useEdgeStore.getState().edges);
|
||||
expect(future).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("clear", () => {
|
||||
it("resets to empty initial state", async () => {
|
||||
const node = createTestNode("1");
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().clear();
|
||||
|
||||
const { past, future } = useHistoryStore.getState();
|
||||
expect(past).toEqual([{ nodes: [], edges: [] }]);
|
||||
expect(future).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("microtask batching", () => {
|
||||
it("only commits the first state when multiple pushState calls happen in the same tick", async () => {
|
||||
const node1 = createTestNode("1");
|
||||
const node2 = createTestNode("2");
|
||||
const node3 = createTestNode("3");
|
||||
|
||||
useNodeStore.setState({ nodes: [node1, node2, node3] });
|
||||
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
|
||||
useHistoryStore
|
||||
.getState()
|
||||
.pushState({ nodes: [node1, node2], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
const { past } = useHistoryStore.getState();
|
||||
expect(past).toHaveLength(2);
|
||||
expect(past[1].nodes).toEqual([node1]);
|
||||
});
|
||||
|
||||
it("commits separately when pushState calls are in different ticks", async () => {
|
||||
const node1 = createTestNode("1");
|
||||
const node2 = createTestNode("2");
|
||||
|
||||
useNodeStore.setState({ nodes: [node1, node2] });
|
||||
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
const { past } = useHistoryStore.getState();
|
||||
expect(past).toHaveLength(3);
|
||||
expect(past[1].nodes).toEqual([node1]);
|
||||
expect(past[2].nodes).toEqual([node2]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("edges in undo/redo", () => {
|
||||
it("restores edges on undo and redo", async () => {
|
||||
const edge = createTestEdge("e1", "1", "2");
|
||||
useEdgeStore.setState({ edges: [edge] });
|
||||
|
||||
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useEdgeStore.getState().edges).toEqual([]);
|
||||
|
||||
useHistoryStore.getState().redo();
|
||||
expect(useEdgeStore.getState().edges).toEqual([edge]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("pushState clears future", () => {
|
||||
it("clears future when a new state is pushed after undo", async () => {
|
||||
const node1 = createTestNode("1");
|
||||
const node2 = createTestNode("2");
|
||||
const node3 = createTestNode("3");
|
||||
|
||||
// Initialize empty
|
||||
useHistoryStore.getState().initializeHistory();
|
||||
|
||||
// First change: set [node1]
|
||||
useNodeStore.setState({ nodes: [node1] });
|
||||
|
||||
// Second change: set [node1, node2], push pre-change [node1]
|
||||
useNodeStore.setState({ nodes: [node1, node2] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
// Undo: back to [node1]
|
||||
useHistoryStore.getState().undo();
|
||||
expect(useHistoryStore.getState().future).toHaveLength(1);
|
||||
|
||||
// New diverging change: add node3 instead of node2
|
||||
useNodeStore.setState({ nodes: [node1, node3] });
|
||||
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
|
||||
await flushMicrotasks();
|
||||
|
||||
expect(useHistoryStore.getState().future).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,791 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { useNodeStore } from "../stores/nodeStore";
|
||||
import { useHistoryStore } from "../stores/historyStore";
|
||||
import { useEdgeStore } from "../stores/edgeStore";
|
||||
import { BlockUIType } from "../components/types";
|
||||
import type { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import type { CustomNodeData } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
|
||||
|
||||
function createTestNode(overrides: {
|
||||
id: string;
|
||||
position?: { x: number; y: number };
|
||||
data?: Partial<CustomNodeData>;
|
||||
}): CustomNode {
|
||||
const defaults: CustomNodeData = {
|
||||
hardcodedValues: {},
|
||||
title: "Test Block",
|
||||
description: "A test block",
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: BlockUIType.STANDARD,
|
||||
block_id: "test-block-id",
|
||||
costs: [],
|
||||
categories: [],
|
||||
};
|
||||
|
||||
return {
|
||||
id: overrides.id,
|
||||
type: "custom",
|
||||
position: overrides.position ?? { x: 0, y: 0 },
|
||||
data: { ...defaults, ...overrides.data },
|
||||
};
|
||||
}
|
||||
|
||||
function createExecutionResult(
|
||||
overrides: Partial<NodeExecutionResult> = {},
|
||||
): NodeExecutionResult {
|
||||
return {
|
||||
node_exec_id: overrides.node_exec_id ?? "exec-1",
|
||||
node_id: overrides.node_id ?? "1",
|
||||
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
|
||||
graph_id: overrides.graph_id ?? "graph-1",
|
||||
graph_version: overrides.graph_version ?? 1,
|
||||
user_id: overrides.user_id ?? "test-user",
|
||||
block_id: overrides.block_id ?? "block-1",
|
||||
status: overrides.status ?? "COMPLETED",
|
||||
input_data: overrides.input_data ?? { input_key: "input_value" },
|
||||
output_data: overrides.output_data ?? { output_key: ["output_value"] },
|
||||
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
|
||||
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
|
||||
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
|
||||
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
|
||||
};
|
||||
}
|
||||
|
||||
function resetStores() {
|
||||
useNodeStore.setState({
|
||||
nodes: [],
|
||||
nodeCounter: 0,
|
||||
nodeAdvancedStates: {},
|
||||
latestNodeInputData: {},
|
||||
latestNodeOutputData: {},
|
||||
accumulatedNodeInputData: {},
|
||||
accumulatedNodeOutputData: {},
|
||||
nodesInResolutionMode: new Set(),
|
||||
brokenEdgeIDs: new Map(),
|
||||
nodeResolutionData: new Map(),
|
||||
});
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
useHistoryStore.setState({ past: [], future: [] });
|
||||
}
|
||||
|
||||
describe("nodeStore", () => {
|
||||
beforeEach(() => {
|
||||
resetStores();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("node lifecycle", () => {
|
||||
it("starts with empty nodes", () => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toEqual([]);
|
||||
});
|
||||
|
||||
it("adds a single node with addNode", () => {
|
||||
const node = createTestNode({ id: "1" });
|
||||
useNodeStore.getState().addNode(node);
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
expect(nodes[0].id).toBe("1");
|
||||
});
|
||||
|
||||
it("sets nodes with setNodes, replacing existing ones", () => {
|
||||
const node1 = createTestNode({ id: "1" });
|
||||
const node2 = createTestNode({ id: "2" });
|
||||
useNodeStore.getState().addNode(node1);
|
||||
|
||||
useNodeStore.getState().setNodes([node2]);
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
expect(nodes[0].id).toBe("2");
|
||||
});
|
||||
|
||||
it("removes nodes via onNodesChange", () => {
|
||||
const node = createTestNode({ id: "1" });
|
||||
useNodeStore.getState().setNodes([node]);
|
||||
|
||||
useNodeStore.getState().onNodesChange([{ type: "remove", id: "1" }]);
|
||||
|
||||
expect(useNodeStore.getState().nodes).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("updates node data with updateNodeData", () => {
|
||||
const node = createTestNode({ id: "1" });
|
||||
useNodeStore.getState().addNode(node);
|
||||
|
||||
useNodeStore.getState().updateNodeData("1", { title: "Updated Title" });
|
||||
|
||||
const updated = useNodeStore.getState().nodes[0];
|
||||
expect(updated.data.title).toBe("Updated Title");
|
||||
expect(updated.data.block_id).toBe("test-block-id");
|
||||
});
|
||||
|
||||
it("updateNodeData does not affect other nodes", () => {
|
||||
const node1 = createTestNode({ id: "1" });
|
||||
const node2 = createTestNode({
|
||||
id: "2",
|
||||
data: { title: "Node 2" },
|
||||
});
|
||||
useNodeStore.getState().setNodes([node1, node2]);
|
||||
|
||||
useNodeStore.getState().updateNodeData("1", { title: "Changed" });
|
||||
|
||||
expect(useNodeStore.getState().nodes[1].data.title).toBe("Node 2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("bulk operations", () => {
|
||||
it("adds multiple nodes with addNodes", () => {
|
||||
const nodes = [
|
||||
createTestNode({ id: "1" }),
|
||||
createTestNode({ id: "2" }),
|
||||
createTestNode({ id: "3" }),
|
||||
];
|
||||
useNodeStore.getState().addNodes(nodes);
|
||||
|
||||
expect(useNodeStore.getState().nodes).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("removes multiple nodes via onNodesChange", () => {
|
||||
const nodes = [
|
||||
createTestNode({ id: "1" }),
|
||||
createTestNode({ id: "2" }),
|
||||
createTestNode({ id: "3" }),
|
||||
];
|
||||
useNodeStore.getState().setNodes(nodes);
|
||||
|
||||
useNodeStore.getState().onNodesChange([
|
||||
{ type: "remove", id: "1" },
|
||||
{ type: "remove", id: "3" },
|
||||
]);
|
||||
|
||||
const remaining = useNodeStore.getState().nodes;
|
||||
expect(remaining).toHaveLength(1);
|
||||
expect(remaining[0].id).toBe("2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("nodeCounter", () => {
|
||||
it("starts at zero", () => {
|
||||
expect(useNodeStore.getState().nodeCounter).toBe(0);
|
||||
});
|
||||
|
||||
it("increments the counter", () => {
|
||||
useNodeStore.getState().incrementNodeCounter();
|
||||
expect(useNodeStore.getState().nodeCounter).toBe(1);
|
||||
|
||||
useNodeStore.getState().incrementNodeCounter();
|
||||
expect(useNodeStore.getState().nodeCounter).toBe(2);
|
||||
});
|
||||
|
||||
it("sets the counter to a specific value", () => {
|
||||
useNodeStore.getState().setNodeCounter(42);
|
||||
expect(useNodeStore.getState().nodeCounter).toBe(42);
|
||||
});
|
||||
});
|
||||
|
||||
describe("advanced states", () => {
|
||||
it("defaults to false for unknown node IDs", () => {
|
||||
expect(useNodeStore.getState().getShowAdvanced("unknown")).toBe(false);
|
||||
});
|
||||
|
||||
it("toggles advanced state", () => {
|
||||
useNodeStore.getState().toggleAdvanced("node-1");
|
||||
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
|
||||
|
||||
useNodeStore.getState().toggleAdvanced("node-1");
|
||||
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
|
||||
});
|
||||
|
||||
it("sets advanced state explicitly", () => {
|
||||
useNodeStore.getState().setShowAdvanced("node-1", true);
|
||||
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
|
||||
|
||||
useNodeStore.getState().setShowAdvanced("node-1", false);
|
||||
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("convertCustomNodeToBackendNode", () => {
|
||||
it("converts a node with minimal data", () => {
|
||||
const node = createTestNode({
|
||||
id: "42",
|
||||
position: { x: 100, y: 200 },
|
||||
});
|
||||
|
||||
const backend = useNodeStore
|
||||
.getState()
|
||||
.convertCustomNodeToBackendNode(node);
|
||||
|
||||
expect(backend.id).toBe("42");
|
||||
expect(backend.block_id).toBe("test-block-id");
|
||||
expect(backend.input_default).toEqual({});
|
||||
expect(backend.metadata).toEqual({ position: { x: 100, y: 200 } });
|
||||
});
|
||||
|
||||
it("includes customized_name when present in metadata", () => {
|
||||
const node = createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
metadata: { customized_name: "My Custom Name" },
|
||||
},
|
||||
});
|
||||
|
||||
const backend = useNodeStore
|
||||
.getState()
|
||||
.convertCustomNodeToBackendNode(node);
|
||||
|
||||
expect(backend.metadata).toHaveProperty(
|
||||
"customized_name",
|
||||
"My Custom Name",
|
||||
);
|
||||
});
|
||||
|
||||
it("includes credentials_optional when present in metadata", () => {
|
||||
const node = createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
metadata: { credentials_optional: true },
|
||||
},
|
||||
});
|
||||
|
||||
const backend = useNodeStore
|
||||
.getState()
|
||||
.convertCustomNodeToBackendNode(node);
|
||||
|
||||
expect(backend.metadata).toHaveProperty("credentials_optional", true);
|
||||
});
|
||||
|
||||
it("prunes empty values from hardcodedValues", () => {
|
||||
const node = createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
hardcodedValues: { filled: "value", empty: "" },
|
||||
},
|
||||
});
|
||||
|
||||
const backend = useNodeStore
|
||||
.getState()
|
||||
.convertCustomNodeToBackendNode(node);
|
||||
|
||||
expect(backend.input_default).toEqual({ filled: "value" });
|
||||
expect(backend.input_default).not.toHaveProperty("empty");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getBackendNodes", () => {
|
||||
it("converts all nodes to backend format", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodes([
|
||||
createTestNode({ id: "1", position: { x: 0, y: 0 } }),
|
||||
createTestNode({ id: "2", position: { x: 100, y: 100 } }),
|
||||
]);
|
||||
|
||||
const backendNodes = useNodeStore.getState().getBackendNodes();
|
||||
|
||||
expect(backendNodes).toHaveLength(2);
|
||||
expect(backendNodes[0].id).toBe("1");
|
||||
expect(backendNodes[1].id).toBe("2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("node status", () => {
|
||||
it("returns undefined for a node with no status", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("updates node status", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
|
||||
expect(useNodeStore.getState().getNodeStatus("1")).toBe("RUNNING");
|
||||
|
||||
useNodeStore.getState().updateNodeStatus("1", "COMPLETED");
|
||||
expect(useNodeStore.getState().getNodeStatus("1")).toBe("COMPLETED");
|
||||
});
|
||||
|
||||
it("cleans all node statuses", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
|
||||
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
|
||||
useNodeStore.getState().updateNodeStatus("2", "COMPLETED");
|
||||
|
||||
useNodeStore.getState().cleanNodesStatuses();
|
||||
|
||||
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
|
||||
expect(useNodeStore.getState().getNodeStatus("2")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("updating status for non-existent node does not crash", () => {
|
||||
useNodeStore.getState().updateNodeStatus("nonexistent", "RUNNING");
|
||||
expect(
|
||||
useNodeStore.getState().getNodeStatus("nonexistent"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("execution result tracking", () => {
|
||||
it("returns empty array for node with no results", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
|
||||
});
|
||||
|
||||
it("tracks a single execution result", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
const result = createExecutionResult({ node_id: "1" });
|
||||
|
||||
useNodeStore.getState().updateNodeExecutionResult("1", result);
|
||||
|
||||
const results = useNodeStore.getState().getNodeExecutionResults("1");
|
||||
expect(results).toHaveLength(1);
|
||||
expect(results[0].node_exec_id).toBe("exec-1");
|
||||
});
|
||||
|
||||
it("accumulates multiple execution results", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
input_data: { key: "val1" },
|
||||
output_data: { key: ["out1"] },
|
||||
}),
|
||||
);
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-2",
|
||||
input_data: { key: "val2" },
|
||||
output_data: { key: ["out2"] },
|
||||
}),
|
||||
);
|
||||
|
||||
expect(useNodeStore.getState().getNodeExecutionResults("1")).toHaveLength(
|
||||
2,
|
||||
);
|
||||
});
|
||||
|
||||
it("updates latest input/output data", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
input_data: { key: "first" },
|
||||
output_data: { key: ["first_out"] },
|
||||
}),
|
||||
);
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-2",
|
||||
input_data: { key: "second" },
|
||||
output_data: { key: ["second_out"] },
|
||||
}),
|
||||
);
|
||||
|
||||
expect(useNodeStore.getState().getLatestNodeInputData("1")).toEqual({
|
||||
key: "second",
|
||||
});
|
||||
expect(useNodeStore.getState().getLatestNodeOutputData("1")).toEqual({
|
||||
key: ["second_out"],
|
||||
});
|
||||
});
|
||||
|
||||
it("accumulates input/output data across results", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
input_data: { key: "val1" },
|
||||
output_data: { key: ["out1"] },
|
||||
}),
|
||||
);
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-2",
|
||||
input_data: { key: "val2" },
|
||||
output_data: { key: ["out2"] },
|
||||
}),
|
||||
);
|
||||
|
||||
const accInput = useNodeStore.getState().getAccumulatedNodeInputData("1");
|
||||
expect(accInput.key).toEqual(["val1", "val2"]);
|
||||
|
||||
const accOutput = useNodeStore
|
||||
.getState()
|
||||
.getAccumulatedNodeOutputData("1");
|
||||
expect(accOutput.key).toEqual(["out1", "out2"]);
|
||||
});
|
||||
|
||||
it("deduplicates execution results by node_exec_id", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
input_data: { key: "original" },
|
||||
output_data: { key: ["original_out"] },
|
||||
}),
|
||||
);
|
||||
useNodeStore.getState().updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({
|
||||
node_exec_id: "exec-1",
|
||||
input_data: { key: "updated" },
|
||||
output_data: { key: ["updated_out"] },
|
||||
}),
|
||||
);
|
||||
|
||||
const results = useNodeStore.getState().getNodeExecutionResults("1");
|
||||
expect(results).toHaveLength(1);
|
||||
expect(results[0].input_data).toEqual({ key: "updated" });
|
||||
});
|
||||
|
||||
it("returns the latest execution result", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({ node_exec_id: "exec-1" }),
|
||||
);
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({ node_exec_id: "exec-2" }),
|
||||
);
|
||||
|
||||
const latest = useNodeStore.getState().getLatestNodeExecutionResult("1");
|
||||
expect(latest?.node_exec_id).toBe("exec-2");
|
||||
});
|
||||
|
||||
it("returns undefined for latest result on unknown node", () => {
|
||||
expect(
|
||||
useNodeStore.getState().getLatestNodeExecutionResult("unknown"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("clears all execution results", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeExecutionResult(
|
||||
"1",
|
||||
createExecutionResult({ node_exec_id: "exec-1" }),
|
||||
);
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeExecutionResult(
|
||||
"2",
|
||||
createExecutionResult({ node_exec_id: "exec-2" }),
|
||||
);
|
||||
|
||||
useNodeStore.getState().clearAllNodeExecutionResults();
|
||||
|
||||
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
|
||||
expect(useNodeStore.getState().getNodeExecutionResults("2")).toEqual([]);
|
||||
expect(
|
||||
useNodeStore.getState().getLatestNodeInputData("1"),
|
||||
).toBeUndefined();
|
||||
expect(
|
||||
useNodeStore.getState().getLatestNodeOutputData("1"),
|
||||
).toBeUndefined();
|
||||
expect(useNodeStore.getState().getAccumulatedNodeInputData("1")).toEqual(
|
||||
{},
|
||||
);
|
||||
expect(useNodeStore.getState().getAccumulatedNodeOutputData("1")).toEqual(
|
||||
{},
|
||||
);
|
||||
});
|
||||
|
||||
it("returns empty object for accumulated data on unknown node", () => {
|
||||
expect(
|
||||
useNodeStore.getState().getAccumulatedNodeInputData("unknown"),
|
||||
).toEqual({});
|
||||
expect(
|
||||
useNodeStore.getState().getAccumulatedNodeOutputData("unknown"),
|
||||
).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe("getNodeBlockUIType", () => {
|
||||
it("returns the node UI type", () => {
|
||||
useNodeStore.getState().addNode(
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
uiType: BlockUIType.INPUT,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(useNodeStore.getState().getNodeBlockUIType("1")).toBe(
|
||||
BlockUIType.INPUT,
|
||||
);
|
||||
});
|
||||
|
||||
it("defaults to STANDARD for unknown node IDs", () => {
|
||||
expect(useNodeStore.getState().getNodeBlockUIType("unknown")).toBe(
|
||||
BlockUIType.STANDARD,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasWebhookNodes", () => {
|
||||
it("returns false when there are no webhook nodes", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
expect(useNodeStore.getState().hasWebhookNodes()).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when a WEBHOOK node exists", () => {
|
||||
useNodeStore.getState().addNode(
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
uiType: BlockUIType.WEBHOOK,
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true when a WEBHOOK_MANUAL node exists", () => {
|
||||
useNodeStore.getState().addNode(
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
uiType: BlockUIType.WEBHOOK_MANUAL,
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("node errors", () => {
|
||||
it("returns undefined for a node with no errors", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("sets and retrieves node errors", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
const errors = { field1: "required", field2: "invalid" };
|
||||
useNodeStore.getState().updateNodeErrors("1", errors);
|
||||
|
||||
expect(useNodeStore.getState().getNodeErrors("1")).toEqual(errors);
|
||||
});
|
||||
|
||||
it("clears errors for a specific node", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
|
||||
useNodeStore.getState().updateNodeErrors("1", { f: "err" });
|
||||
useNodeStore.getState().updateNodeErrors("2", { g: "err2" });
|
||||
|
||||
useNodeStore.getState().clearNodeErrors("1");
|
||||
|
||||
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
|
||||
expect(useNodeStore.getState().getNodeErrors("2")).toEqual({ g: "err2" });
|
||||
});
|
||||
|
||||
it("clears all node errors", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
|
||||
useNodeStore.getState().updateNodeErrors("1", { a: "err1" });
|
||||
useNodeStore.getState().updateNodeErrors("2", { b: "err2" });
|
||||
|
||||
useNodeStore.getState().clearAllNodeErrors();
|
||||
|
||||
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
|
||||
expect(useNodeStore.getState().getNodeErrors("2")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("sets errors by backend ID matching node id", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "backend-1" }));
|
||||
|
||||
useNodeStore
|
||||
.getState()
|
||||
.setNodeErrorsForBackendId("backend-1", { x: "error" });
|
||||
|
||||
expect(useNodeStore.getState().getNodeErrors("backend-1")).toEqual({
|
||||
x: "error",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("getHardCodedValues", () => {
|
||||
it("returns hardcoded values for a node", () => {
|
||||
useNodeStore.getState().addNode(
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
hardcodedValues: { key: "value" },
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(useNodeStore.getState().getHardCodedValues("1")).toEqual({
|
||||
key: "value",
|
||||
});
|
||||
});
|
||||
|
||||
it("returns empty object for unknown node", () => {
|
||||
expect(useNodeStore.getState().getHardCodedValues("unknown")).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe("credentials optional", () => {
|
||||
it("sets credentials_optional in node metadata", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore.getState().setCredentialsOptional("1", true);
|
||||
|
||||
const node = useNodeStore.getState().nodes[0];
|
||||
expect(node.data.metadata?.credentials_optional).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolution mode", () => {
|
||||
it("defaults to not in resolution mode", () => {
|
||||
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
|
||||
});
|
||||
|
||||
it("enters and exits resolution mode", () => {
|
||||
useNodeStore.getState().setNodeResolutionMode("1", true);
|
||||
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(true);
|
||||
|
||||
useNodeStore.getState().setNodeResolutionMode("1", false);
|
||||
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
|
||||
});
|
||||
|
||||
it("tracks broken edge IDs", () => {
|
||||
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
|
||||
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(true);
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-3")).toBe(false);
|
||||
});
|
||||
|
||||
it("removes individual broken edge IDs", () => {
|
||||
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
|
||||
useNodeStore.getState().removeBrokenEdgeID("node-1", "edge-1");
|
||||
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
|
||||
});
|
||||
|
||||
it("clears all resolution state", () => {
|
||||
useNodeStore.getState().setNodeResolutionMode("1", true);
|
||||
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
|
||||
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
|
||||
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
|
||||
});
|
||||
|
||||
it("cleans up broken edges when exiting resolution mode", () => {
|
||||
useNodeStore.getState().setNodeResolutionMode("1", true);
|
||||
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
|
||||
|
||||
useNodeStore.getState().setNodeResolutionMode("1", false);
|
||||
|
||||
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("edge cases", () => {
|
||||
it("handles updating data on a non-existent node gracefully", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeData("nonexistent", { title: "New Title" });
|
||||
|
||||
expect(useNodeStore.getState().nodes).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("handles removing a non-existent node gracefully", () => {
|
||||
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
|
||||
|
||||
useNodeStore
|
||||
.getState()
|
||||
.onNodesChange([{ type: "remove", id: "nonexistent" }]);
|
||||
|
||||
expect(useNodeStore.getState().nodes).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("handles duplicate node IDs in addNodes", () => {
|
||||
useNodeStore.getState().addNodes([
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: { title: "First" },
|
||||
}),
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: { title: "Second" },
|
||||
}),
|
||||
]);
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(2);
|
||||
expect(nodes[0].data.title).toBe("First");
|
||||
expect(nodes[1].data.title).toBe("Second");
|
||||
});
|
||||
|
||||
it("updating node status mid-execution preserves other data", () => {
|
||||
useNodeStore.getState().addNode(
|
||||
createTestNode({
|
||||
id: "1",
|
||||
data: {
|
||||
title: "My Node",
|
||||
hardcodedValues: { key: "val" },
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
|
||||
|
||||
const node = useNodeStore.getState().nodes[0];
|
||||
expect(node.data.status).toBe("RUNNING");
|
||||
expect(node.data.title).toBe("My Node");
|
||||
expect(node.data.hardcodedValues).toEqual({ key: "val" });
|
||||
});
|
||||
|
||||
it("execution result for non-existent node does not add it", () => {
|
||||
useNodeStore
|
||||
.getState()
|
||||
.updateNodeExecutionResult(
|
||||
"nonexistent",
|
||||
createExecutionResult({ node_exec_id: "exec-1" }),
|
||||
);
|
||||
|
||||
expect(useNodeStore.getState().nodes).toHaveLength(0);
|
||||
expect(
|
||||
useNodeStore.getState().getNodeExecutionResults("nonexistent"),
|
||||
).toEqual([]);
|
||||
});
|
||||
|
||||
it("getBackendNodes returns empty array when no nodes exist", () => {
|
||||
expect(useNodeStore.getState().getBackendNodes()).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,118 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import { useTutorialStore } from "../stores/tutorialStore";
|
||||
|
||||
beforeEach(() => {
|
||||
useTutorialStore.setState({
|
||||
isTutorialRunning: false,
|
||||
currentStep: 0,
|
||||
forceOpenRunInputDialog: false,
|
||||
tutorialInputValues: {},
|
||||
});
|
||||
});
|
||||
|
||||
describe("tutorialStore", () => {
|
||||
describe("initial state", () => {
|
||||
it("starts with tutorial not running at step 0", () => {
|
||||
const state = useTutorialStore.getState();
|
||||
expect(state.isTutorialRunning).toBe(false);
|
||||
expect(state.currentStep).toBe(0);
|
||||
expect(state.forceOpenRunInputDialog).toBe(false);
|
||||
expect(state.tutorialInputValues).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe("setIsTutorialRunning", () => {
|
||||
it("starts the tutorial", () => {
|
||||
useTutorialStore.getState().setIsTutorialRunning(true);
|
||||
expect(useTutorialStore.getState().isTutorialRunning).toBe(true);
|
||||
});
|
||||
|
||||
it("stops the tutorial", () => {
|
||||
useTutorialStore.getState().setIsTutorialRunning(true);
|
||||
useTutorialStore.getState().setIsTutorialRunning(false);
|
||||
expect(useTutorialStore.getState().isTutorialRunning).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setCurrentStep", () => {
|
||||
it("advances to a step", () => {
|
||||
useTutorialStore.getState().setCurrentStep(3);
|
||||
expect(useTutorialStore.getState().currentStep).toBe(3);
|
||||
});
|
||||
|
||||
it("can go back to a previous step", () => {
|
||||
useTutorialStore.getState().setCurrentStep(5);
|
||||
useTutorialStore.getState().setCurrentStep(2);
|
||||
expect(useTutorialStore.getState().currentStep).toBe(2);
|
||||
});
|
||||
|
||||
it("can reset to step 0", () => {
|
||||
useTutorialStore.getState().setCurrentStep(4);
|
||||
useTutorialStore.getState().setCurrentStep(0);
|
||||
expect(useTutorialStore.getState().currentStep).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setForceOpenRunInputDialog", () => {
|
||||
it("forces the dialog open", () => {
|
||||
useTutorialStore.getState().setForceOpenRunInputDialog(true);
|
||||
expect(useTutorialStore.getState().forceOpenRunInputDialog).toBe(true);
|
||||
});
|
||||
|
||||
it("closes the forced dialog", () => {
|
||||
useTutorialStore.getState().setForceOpenRunInputDialog(true);
|
||||
useTutorialStore.getState().setForceOpenRunInputDialog(false);
|
||||
expect(useTutorialStore.getState().forceOpenRunInputDialog).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setTutorialInputValues", () => {
|
||||
it("sets input values", () => {
|
||||
useTutorialStore
|
||||
.getState()
|
||||
.setTutorialInputValues({ topic: "AI agents" });
|
||||
expect(useTutorialStore.getState().tutorialInputValues).toEqual({
|
||||
topic: "AI agents",
|
||||
});
|
||||
});
|
||||
|
||||
it("replaces previous values entirely", () => {
|
||||
useTutorialStore.getState().setTutorialInputValues({ a: "1" });
|
||||
useTutorialStore.getState().setTutorialInputValues({ b: "2" });
|
||||
expect(useTutorialStore.getState().tutorialInputValues).toEqual({
|
||||
b: "2",
|
||||
});
|
||||
});
|
||||
|
||||
it("clears values with empty object", () => {
|
||||
useTutorialStore.getState().setTutorialInputValues({ x: "y" });
|
||||
useTutorialStore.getState().setTutorialInputValues({});
|
||||
expect(useTutorialStore.getState().tutorialInputValues).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe("step progression lifecycle", () => {
|
||||
it("simulates a full tutorial run", () => {
|
||||
useTutorialStore.getState().setIsTutorialRunning(true);
|
||||
expect(useTutorialStore.getState().isTutorialRunning).toBe(true);
|
||||
|
||||
useTutorialStore.getState().setCurrentStep(1);
|
||||
useTutorialStore.getState().setCurrentStep(2);
|
||||
useTutorialStore.getState().setCurrentStep(3);
|
||||
|
||||
useTutorialStore.getState().setForceOpenRunInputDialog(true);
|
||||
useTutorialStore
|
||||
.getState()
|
||||
.setTutorialInputValues({ prompt: "test prompt" });
|
||||
useTutorialStore.getState().setForceOpenRunInputDialog(false);
|
||||
|
||||
useTutorialStore.getState().setCurrentStep(4);
|
||||
useTutorialStore.getState().setIsTutorialRunning(false);
|
||||
useTutorialStore.getState().setCurrentStep(0);
|
||||
|
||||
const state = useTutorialStore.getState();
|
||||
expect(state.isTutorialRunning).toBe(false);
|
||||
expect(state.currentStep).toBe(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,567 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import { BlockUIType } from "../components/types";
|
||||
|
||||
// ---- Mocks ----
|
||||
|
||||
const mockGetViewport = vi.fn(() => ({ x: 0, y: 0, zoom: 1 }));
|
||||
|
||||
vi.mock("@xyflow/react", async () => {
|
||||
const actual = await vi.importActual("@xyflow/react");
|
||||
return {
|
||||
...actual,
|
||||
useReactFlow: vi.fn(() => ({
|
||||
getViewport: mockGetViewport,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
const mockToast = vi.fn();
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: vi.fn(() => ({ toast: mockToast })),
|
||||
}));
|
||||
|
||||
let uuidCounter = 0;
|
||||
vi.mock("uuid", () => ({
|
||||
v4: vi.fn(() => `new-uuid-${++uuidCounter}`),
|
||||
}));
|
||||
|
||||
// Mock navigator.clipboard
|
||||
const mockWriteText = vi.fn(() => Promise.resolve());
|
||||
const mockReadText = vi.fn(() => Promise.resolve(""));
|
||||
|
||||
Object.defineProperty(navigator, "clipboard", {
|
||||
value: {
|
||||
writeText: mockWriteText,
|
||||
readText: mockReadText,
|
||||
},
|
||||
writable: true,
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// Mock window.innerWidth / innerHeight for viewport centering calculations
|
||||
Object.defineProperty(window, "innerWidth", { value: 1000, writable: true });
|
||||
Object.defineProperty(window, "innerHeight", { value: 800, writable: true });
|
||||
|
||||
import { useCopyPaste } from "../components/FlowEditor/Flow/useCopyPaste";
|
||||
import { useNodeStore } from "../stores/nodeStore";
|
||||
import { useEdgeStore } from "../stores/edgeStore";
|
||||
import { useHistoryStore } from "../stores/historyStore";
|
||||
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
|
||||
|
||||
const CLIPBOARD_PREFIX = "autogpt-flow-data:";
|
||||
|
||||
function createTestNode(
|
||||
id: string,
|
||||
overrides: Partial<CustomNode> = {},
|
||||
): CustomNode {
|
||||
return {
|
||||
id,
|
||||
type: "custom",
|
||||
position: overrides.position ?? { x: 100, y: 200 },
|
||||
selected: overrides.selected,
|
||||
data: {
|
||||
hardcodedValues: {},
|
||||
title: `Node ${id}`,
|
||||
description: "test node",
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: BlockUIType.STANDARD,
|
||||
block_id: `block-${id}`,
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides.data,
|
||||
},
|
||||
} as CustomNode;
|
||||
}
|
||||
|
||||
function createTestEdge(
|
||||
id: string,
|
||||
source: string,
|
||||
target: string,
|
||||
sourceHandle = "out",
|
||||
targetHandle = "in",
|
||||
): CustomEdge {
|
||||
return {
|
||||
id,
|
||||
source,
|
||||
target,
|
||||
sourceHandle,
|
||||
targetHandle,
|
||||
} as CustomEdge;
|
||||
}
|
||||
|
||||
function makeCopyEvent(): KeyboardEvent {
|
||||
return new KeyboardEvent("keydown", {
|
||||
key: "c",
|
||||
ctrlKey: true,
|
||||
bubbles: true,
|
||||
});
|
||||
}
|
||||
|
||||
function makePasteEvent(): KeyboardEvent {
|
||||
return new KeyboardEvent("keydown", {
|
||||
key: "v",
|
||||
ctrlKey: true,
|
||||
bubbles: true,
|
||||
});
|
||||
}
|
||||
|
||||
function clipboardPayload(nodes: CustomNode[], edges: CustomEdge[]): string {
|
||||
return `${CLIPBOARD_PREFIX}${JSON.stringify({ nodes, edges })}`;
|
||||
}
|
||||
|
||||
describe("useCopyPaste", () => {
|
||||
beforeEach(() => {
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
useHistoryStore.getState().clear();
|
||||
mockWriteText.mockClear();
|
||||
mockReadText.mockClear();
|
||||
mockToast.mockClear();
|
||||
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
|
||||
uuidCounter = 0;
|
||||
|
||||
// Ensure no input element is focused
|
||||
if (document.activeElement && document.activeElement !== document.body) {
|
||||
(document.activeElement as HTMLElement).blur();
|
||||
}
|
||||
});
|
||||
|
||||
describe("copy (Ctrl+C)", () => {
|
||||
it("copies a single selected node to clipboard with prefix", async () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockWriteText).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const written = (mockWriteText.mock.calls as string[][])[0][0];
|
||||
expect(written.startsWith(CLIPBOARD_PREFIX)).toBe(true);
|
||||
|
||||
const parsed = JSON.parse(written.slice(CLIPBOARD_PREFIX.length));
|
||||
expect(parsed.nodes).toHaveLength(1);
|
||||
expect(parsed.nodes[0].id).toBe("1");
|
||||
expect(parsed.edges).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("shows a success toast after copying", async () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: "Copied successfully",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("copies multiple connected nodes and preserves internal edges", async () => {
|
||||
const nodeA = createTestNode("a", { selected: true });
|
||||
const nodeB = createTestNode("b", { selected: true });
|
||||
const nodeC = createTestNode("c", { selected: false });
|
||||
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
|
||||
|
||||
useEdgeStore.setState({
|
||||
edges: [
|
||||
createTestEdge("e-ab", "a", "b"),
|
||||
createTestEdge("e-bc", "b", "c"),
|
||||
],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockWriteText).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
(mockWriteText.mock.calls as string[][])[0][0].slice(
|
||||
CLIPBOARD_PREFIX.length,
|
||||
),
|
||||
);
|
||||
expect(parsed.nodes).toHaveLength(2);
|
||||
expect(parsed.edges).toHaveLength(1);
|
||||
expect(parsed.edges[0].id).toBe("e-ab");
|
||||
});
|
||||
|
||||
it("drops external edges where one endpoint is not selected", async () => {
|
||||
const nodeA = createTestNode("a", { selected: true });
|
||||
const nodeB = createTestNode("b", { selected: false });
|
||||
useNodeStore.setState({ nodes: [nodeA, nodeB] });
|
||||
|
||||
useEdgeStore.setState({
|
||||
edges: [createTestEdge("e-ab", "a", "b")],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockWriteText).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
(mockWriteText.mock.calls as string[][])[0][0].slice(
|
||||
CLIPBOARD_PREFIX.length,
|
||||
),
|
||||
);
|
||||
expect(parsed.nodes).toHaveLength(1);
|
||||
expect(parsed.edges).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("copies nothing when no nodes are selected", async () => {
|
||||
const node = createTestNode("1", { selected: false });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockWriteText).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
(mockWriteText.mock.calls as string[][])[0][0].slice(
|
||||
CLIPBOARD_PREFIX.length,
|
||||
),
|
||||
);
|
||||
expect(parsed.nodes).toHaveLength(0);
|
||||
expect(parsed.edges).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("paste (Ctrl+V)", () => {
|
||||
it("creates new nodes with new UUIDs", async () => {
|
||||
const node = createTestNode("orig", {
|
||||
selected: true,
|
||||
position: { x: 100, y: 200 },
|
||||
});
|
||||
|
||||
mockReadText.mockResolvedValue(clipboardPayload([node], []));
|
||||
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
});
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes[0].id).toBe("new-uuid-1");
|
||||
expect(nodes[0].id).not.toBe("orig");
|
||||
});
|
||||
|
||||
it("centers pasted nodes in the current viewport", async () => {
|
||||
// Viewport at origin, zoom 1 => center = (500, 400)
|
||||
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
|
||||
|
||||
const node = createTestNode("orig", {
|
||||
selected: true,
|
||||
position: { x: 100, y: 100 },
|
||||
});
|
||||
|
||||
mockReadText.mockResolvedValue(clipboardPayload([node], []));
|
||||
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
});
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
// Single node: center of bounds = (100, 100)
|
||||
// Viewport center = (500, 400)
|
||||
// Offset = (400, 300)
|
||||
// New position = (100 + 400, 100 + 300) = (500, 400)
|
||||
expect(nodes[0].position).toEqual({ x: 500, y: 400 });
|
||||
});
|
||||
|
||||
it("deselects existing nodes and selects pasted nodes", async () => {
|
||||
const existingNode = createTestNode("existing", {
|
||||
selected: true,
|
||||
position: { x: 0, y: 0 },
|
||||
});
|
||||
|
||||
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
|
||||
|
||||
const nodeToPaste = createTestNode("paste-me", {
|
||||
selected: false,
|
||||
position: { x: 100, y: 100 },
|
||||
});
|
||||
|
||||
mockReadText.mockResolvedValue(clipboardPayload([nodeToPaste], []));
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(2);
|
||||
});
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
const originalNode = nodes.find((n) => n.id === "existing");
|
||||
const pastedNode = nodes.find((n) => n.id !== "existing");
|
||||
|
||||
expect(originalNode!.selected).toBe(false);
|
||||
expect(pastedNode!.selected).toBe(true);
|
||||
});
|
||||
|
||||
it("remaps edge source/target IDs to newly created node IDs", async () => {
|
||||
const nodeA = createTestNode("a", {
|
||||
selected: true,
|
||||
position: { x: 0, y: 0 },
|
||||
});
|
||||
const nodeB = createTestNode("b", {
|
||||
selected: true,
|
||||
position: { x: 200, y: 0 },
|
||||
});
|
||||
const edge = createTestEdge("e-ab", "a", "b", "output", "input");
|
||||
|
||||
mockReadText.mockResolvedValue(clipboardPayload([nodeA, nodeB], [edge]));
|
||||
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
useEdgeStore.setState({ edges: [] });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(2);
|
||||
});
|
||||
|
||||
// Wait for edges to be added too
|
||||
await vi.waitFor(() => {
|
||||
const { edges } = useEdgeStore.getState();
|
||||
expect(edges).toHaveLength(1);
|
||||
});
|
||||
|
||||
const { edges } = useEdgeStore.getState();
|
||||
const newEdge = edges[0];
|
||||
|
||||
// Edge source/target should be remapped to new UUIDs, not "a"/"b"
|
||||
expect(newEdge.source).not.toBe("a");
|
||||
expect(newEdge.target).not.toBe("b");
|
||||
expect(newEdge.source).toBe("new-uuid-1");
|
||||
expect(newEdge.target).toBe("new-uuid-2");
|
||||
expect(newEdge.sourceHandle).toBe("output");
|
||||
expect(newEdge.targetHandle).toBe("input");
|
||||
});
|
||||
|
||||
it("does nothing when clipboard does not have the expected prefix", async () => {
|
||||
mockReadText.mockResolvedValue("some random text");
|
||||
|
||||
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
|
||||
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
// Give async operations time to settle
|
||||
await vi.waitFor(() => {
|
||||
expect(mockReadText).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Ensure no state changes happen after clipboard read
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
expect(nodes[0].id).toBe("1");
|
||||
});
|
||||
});
|
||||
|
||||
it("does nothing when clipboard is empty", async () => {
|
||||
mockReadText.mockResolvedValue("");
|
||||
|
||||
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
|
||||
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockReadText).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Ensure no state changes happen after clipboard read
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
expect(nodes[0].id).toBe("1");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("input field focus guard", () => {
|
||||
it("ignores Ctrl+C when an input element is focused", async () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const input = document.createElement("input");
|
||||
document.body.appendChild(input);
|
||||
input.focus();
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
// Clipboard write should NOT be called
|
||||
expect(mockWriteText).not.toHaveBeenCalled();
|
||||
|
||||
document.body.removeChild(input);
|
||||
});
|
||||
|
||||
it("ignores Ctrl+V when a textarea element is focused", async () => {
|
||||
mockReadText.mockResolvedValue(
|
||||
clipboardPayload(
|
||||
[createTestNode("a", { position: { x: 0, y: 0 } })],
|
||||
[],
|
||||
),
|
||||
);
|
||||
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
|
||||
const textarea = document.createElement("textarea");
|
||||
document.body.appendChild(textarea);
|
||||
textarea.focus();
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makePasteEvent());
|
||||
});
|
||||
|
||||
expect(mockReadText).not.toHaveBeenCalled();
|
||||
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(0);
|
||||
|
||||
document.body.removeChild(textarea);
|
||||
});
|
||||
|
||||
it("ignores keypresses when a contenteditable element is focused", async () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const div = document.createElement("div");
|
||||
div.setAttribute("contenteditable", "true");
|
||||
document.body.appendChild(div);
|
||||
div.focus();
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
act(() => {
|
||||
result.current(makeCopyEvent());
|
||||
});
|
||||
|
||||
expect(mockWriteText).not.toHaveBeenCalled();
|
||||
|
||||
document.body.removeChild(div);
|
||||
});
|
||||
});
|
||||
|
||||
describe("meta key support (macOS)", () => {
|
||||
it("handles Cmd+C (metaKey) the same as Ctrl+C", async () => {
|
||||
const node = createTestNode("1", { selected: true });
|
||||
useNodeStore.setState({ nodes: [node] });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
const metaCopyEvent = new KeyboardEvent("keydown", {
|
||||
key: "c",
|
||||
metaKey: true,
|
||||
bubbles: true,
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current(metaCopyEvent);
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockWriteText).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it("handles Cmd+V (metaKey) the same as Ctrl+V", async () => {
|
||||
const node = createTestNode("orig", {
|
||||
selected: true,
|
||||
position: { x: 0, y: 0 },
|
||||
});
|
||||
mockReadText.mockResolvedValue(clipboardPayload([node], []));
|
||||
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
|
||||
|
||||
const { result } = renderHook(() => useCopyPaste());
|
||||
|
||||
const metaPasteEvent = new KeyboardEvent("keydown", {
|
||||
key: "v",
|
||||
metaKey: true,
|
||||
bubbles: true,
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current(metaPasteEvent);
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const { nodes } = useNodeStore.getState();
|
||||
expect(nodes).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,134 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
|
||||
const mockScreenToFlowPosition = vi.fn((pos: { x: number; y: number }) => pos);
|
||||
const mockFitView = vi.fn();
|
||||
|
||||
vi.mock("@xyflow/react", async () => {
|
||||
const actual = await vi.importActual("@xyflow/react");
|
||||
return {
|
||||
...actual,
|
||||
useReactFlow: () => ({
|
||||
screenToFlowPosition: mockScreenToFlowPosition,
|
||||
fitView: mockFitView,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockSetQueryStates = vi.fn();
|
||||
let mockQueryStateValues: {
|
||||
flowID: string | null;
|
||||
flowVersion: number | null;
|
||||
flowExecutionID: string | null;
|
||||
} = {
|
||||
flowID: null,
|
||||
flowVersion: null,
|
||||
flowExecutionID: null,
|
||||
};
|
||||
|
||||
vi.mock("nuqs", () => ({
|
||||
parseAsString: {},
|
||||
parseAsInteger: {},
|
||||
useQueryStates: vi.fn(() => [mockQueryStateValues, mockSetQueryStates]),
|
||||
}));
|
||||
|
||||
let mockGraphLoading = false;
|
||||
let mockBlocksLoading = false;
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/graphs/graphs", () => ({
|
||||
useGetV1GetSpecificGraph: vi.fn(() => ({
|
||||
data: undefined,
|
||||
isLoading: mockGraphLoading,
|
||||
})),
|
||||
useGetV1GetExecutionDetails: vi.fn(() => ({
|
||||
data: undefined,
|
||||
})),
|
||||
useGetV1ListUserGraphs: vi.fn(() => ({
|
||||
data: undefined,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/default/default", () => ({
|
||||
useGetV2GetSpecificBlocks: vi.fn(() => ({
|
||||
data: undefined,
|
||||
isLoading: mockBlocksLoading,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/helpers", () => ({
|
||||
okData: (res: { data: unknown }) => res?.data,
|
||||
}));
|
||||
|
||||
vi.mock("../components/helper", () => ({
|
||||
convertNodesPlusBlockInfoIntoCustomNodes: vi.fn(),
|
||||
}));
|
||||
|
||||
describe("useFlow", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true });
|
||||
mockGraphLoading = false;
|
||||
mockBlocksLoading = false;
|
||||
mockQueryStateValues = {
|
||||
flowID: null,
|
||||
flowVersion: null,
|
||||
flowExecutionID: null,
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe("loading states", () => {
|
||||
it("returns isFlowContentLoading true when graph is loading", async () => {
|
||||
mockGraphLoading = true;
|
||||
mockQueryStateValues = {
|
||||
flowID: "test-flow",
|
||||
flowVersion: 1,
|
||||
flowExecutionID: null,
|
||||
};
|
||||
|
||||
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
|
||||
const { result } = renderHook(() => useFlow());
|
||||
|
||||
expect(result.current.isFlowContentLoading).toBe(true);
|
||||
});
|
||||
|
||||
it("returns isFlowContentLoading true when blocks are loading", async () => {
|
||||
mockBlocksLoading = true;
|
||||
mockQueryStateValues = {
|
||||
flowID: "test-flow",
|
||||
flowVersion: 1,
|
||||
flowExecutionID: null,
|
||||
};
|
||||
|
||||
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
|
||||
const { result } = renderHook(() => useFlow());
|
||||
|
||||
expect(result.current.isFlowContentLoading).toBe(true);
|
||||
});
|
||||
|
||||
it("returns isFlowContentLoading false when neither is loading", async () => {
|
||||
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
|
||||
const { result } = renderHook(() => useFlow());
|
||||
|
||||
expect(result.current.isFlowContentLoading).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("initial load completion", () => {
|
||||
it("marks initial load complete for new flows without flowID", async () => {
|
||||
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
|
||||
const { result } = renderHook(() => useFlow());
|
||||
|
||||
expect(result.current.isInitialLoadComplete).toBe(false);
|
||||
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(300);
|
||||
});
|
||||
|
||||
expect(result.current.isInitialLoadComplete).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,16 +23,12 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -72,17 +68,13 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ArrowRight, Lightning } from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useJumpBackIn } from "./useJumpBackIn";
|
||||
|
||||
export function JumpBackIn() {
|
||||
const { agent, isLoading } = useJumpBackIn();
|
||||
|
||||
if (isLoading || !agent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between rounded-large border border-zinc-200 bg-gradient-to-r from-zinc-50 to-white px-5 py-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
|
||||
<Lightning size={18} weight="fill" className="text-white" />
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<Text variant="small" className="text-zinc-500">
|
||||
Continue where you left off
|
||||
</Text>
|
||||
<Text variant="body-medium" className="text-zinc-900">
|
||||
{agent.name}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<NextLink href={`/library/agents/${agent.id}`}>
|
||||
<Button variant="primary" size="small" className="gap-1.5">
|
||||
Jump Back In
|
||||
<ArrowRight size={16} />
|
||||
</Button>
|
||||
</NextLink>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
export function useJumpBackIn() {
|
||||
const { data, isLoading } = useGetV2ListLibraryAgents(
|
||||
{
|
||||
page: 1,
|
||||
page_size: 1,
|
||||
sort_by: "updatedAt",
|
||||
},
|
||||
{
|
||||
query: { select: okData },
|
||||
},
|
||||
);
|
||||
|
||||
// The API doesn't include execution data by default (include_executions is
|
||||
// internal to the backend), so recent_executions is always empty here.
|
||||
// We use the most recently updated agent as the "jump back in" candidate
|
||||
// instead — updatedAt is the best available proxy for recent activity.
|
||||
const agent = data?.agents[0] ?? null;
|
||||
|
||||
return {
|
||||
agent,
|
||||
isLoading,
|
||||
};
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import { useEffect, useState, useCallback } from "react";
|
||||
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
|
||||
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
|
||||
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
|
||||
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
|
||||
import { Tab } from "./components/LibraryTabs/LibraryTabs";
|
||||
@@ -39,7 +38,6 @@ export default function LibraryPage() {
|
||||
onAnimationComplete={handleFavoriteAnimationComplete}
|
||||
>
|
||||
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
|
||||
<JumpBackIn />
|
||||
<LibraryActionHeader setSearchTerm={setSearchTerm} />
|
||||
<LibraryAgentList
|
||||
searchTerm={searchTerm}
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
Reference in New Issue
Block a user