Compare commits

..

76 Commits

Author SHA1 Message Date
Zamil Majdy
2a50cf17f4 fix: restore size="icon" on sidebar trigger button 2026-03-17 12:06:39 +07:00
Zamil Majdy
f0e595a447 Merge remote-tracking branch 'origin/dev' into feat/tracking-cost-block 2026-03-17 07:10:44 +07:00
Zamil Majdy
291eaf3341 Merge remote-tracking branch 'origin/dev' into feat/tracking-cost-block 2026-03-17 06:16:46 +07:00
Zamil Majdy
a8fe383481 fix: revert get_credits rename and remove unrelated changes
- Revert get_credits → get_credit_balance rename in DatabaseManager
  and DatabaseManagerClient to preserve existing RPC route name
- Keep DatabaseManagerAsyncClient credits additions (new, no route change)
- Remove text_utils.py (belongs to PR #12400)
- Revert sidebar.tsx change (unrelated)
- Restore removed comments in coercion tests
2026-03-17 06:14:41 +07:00
Zamil Majdy
c342290910 revert(store): also revert content_handlers_test.py from PR #12400 2026-03-17 05:19:26 +07:00
Zamil Majdy
7ff8bc8c5e Revert "fix(copilot): remove synthetic IDs from billing metadata"
This reverts commit b547c734cb.
2026-03-17 04:45:36 +07:00
Zamil Majdy
b547c734cb fix(copilot): remove synthetic IDs from billing metadata
The graph/node IDs in UsageTransactionMetadata are fabricated for copilot
block execution and have no real meaning. Keep only the genuinely useful
fields (block_id, block name, cost filter, reason).
2026-03-17 04:07:03 +07:00
Zamil Majdy
0a6d708411 refactor(copilot): remove redundant pre-compaction rate limit check
The check in routes.py already gates the request before any service
code runs. The SDK-level check reads the same Redis counters with no
token consumption in between, making it a no-op duplicate.
2026-03-17 01:57:48 +07:00
Zamil Majdy
b031cc55ec revert(copilot): restore synthetic IDs in billing metadata
The graph_exec_id, graph_id, and node_id are used for spend calculation.
2026-03-17 01:54:20 +07:00
Zamil Majdy
ca5a0e619e fix(copilot): update rate limit comment with realistic token estimates, remove synthetic IDs from billing metadata
- config.py: Update per-turn token estimate from 10-15K to 25-35K avg
  (accounting for context growth), adjust daily turn estimate to ~70-100
- helpers.py: Remove synthetic graph/node IDs from UsageTransactionMetadata
  since they're fake data; keep only real identifiers (node_exec_id, block_id)
2026-03-17 01:52:38 +07:00
Zamil Majdy
fd75e14eb8 fix(copilot): don't inflate completion tokens to 1 for tool-only responses
When streaming providers don't report usage, the fallback estimator
used max(..., 1) for completion tokens. For tool-only responses with
no text output, this artificially charged 1 phantom token.
2026-03-17 01:37:50 +07:00
Zamil Majdy
1476a580b2 fix(copilot): route credit operations through db_accessor pattern to fix Prisma connection error
The copilot helpers called get_user_credit_model() directly, which uses
Prisma — but the API process doesn't have Prisma connected. Now uses
credit_db() accessor (same pattern as chat_db, graph_db, etc.) that
falls back to the RPC client when Prisma isn't available.
2026-03-17 01:24:59 +07:00
Zamil Majdy
a6af72ed51 revert(data): remove formatting-only changes in credit.py
Reverts unrelated spacing/whitespace changes per reviewer feedback.
2026-03-17 01:10:32 +07:00
Zamil Majdy
dc3fa69f84 fix(platform): address Pwuts review feedback on PR #12385
- routes.py: Remove redundant `Security(auth.requires_user)` dependencies,
  change `user_id` params from `str | None` to `str` with `Security(auth.get_user_id)`,
  remove manual auth check
- content_handlers.py: Move helper functions after the functions that use them
- text_utils.py: Move `split_camelcase` to `backend/util/text.py`
- rate_limit.py: Rename `_PREFIX` to `_USAGE_KEY_PREFIX`, move private helpers
  after public functions
- response_model.py: Convert camelCase class attributes to snake_case with
  `serialization_alias` to preserve JSON contract
- token_tracking.py: Convert `%s` logger formatting to f-strings
- db_manager.py: Rename `get_credits` to `get_credit_balance`
- sdk/service.py: Add pre-compaction rate limit check
- config.py: Document rate limit check placement (pre-turn + pre-compaction)
- UsageLimits: Inline useUsageLimits hook directly into components,
  delete unnecessary wrapper file
2026-03-17 01:07:28 +07:00
Zamil Majdy
8682ef4b71 fix(platform): use lookaround assertions in camelCase regex to prevent polynomial backtracking
Replace capturing groups with zero-width lookahead/lookbehind assertions
in split_camelcase() to eliminate the polynomial regex DoS flagged by
CodeQL (code-scanning/227). The new patterns have no quantifiers inside
capturing groups, so they cannot backtrack.
2026-03-17 00:37:21 +07:00
Zamil Majdy
a3b7bf4424 fix(copilot): use pipeline(transaction=False) for independent rate-limit counters
The INCRBY+EXPIRE pairs operate on separate keys (daily vs weekly) with no
cross-key atomicity requirement. Dropping MULTI/EXEC avoids the overhead.
If the connection drops between INCRBY and EXPIRE, the key survives only
until the next date-based key rotation, so the memory-leak risk is negligible.
2026-03-17 00:31:26 +07:00
Zamil Majdy
80670345af fix(backend): remove dead-code get_credits/spend_credits wrappers and credit_db adapter 2026-03-16 06:22:07 +07:00
Zamil Majdy
3fc10e9fc0 refactor(frontend/copilot): extract UsagePanelContent to own file; accept optional now param in formatResetTime 2026-03-16 06:20:38 +07:00
Zamil Majdy
e71979e0b2 fix(frontend/copilot): parse JSON error body for structured 429 detection in useCopilotStream 2026-03-16 06:19:12 +07:00
Zamil Majdy
1456659dba fix(copilot): remove redundant limit check from rate-limit outer guard in routes 2026-03-16 06:17:11 +07:00
Zamil Majdy
acb0d41a35 fix(copilot): add max(0,...) guards to persist_and_record_usage inputs 2026-03-16 06:15:37 +07:00
Zamil Majdy
4a08482aa6 fix(copilot): catch all exceptions in post-exec credit charge to prevent false ErrorResponse 2026-03-15 23:26:01 +07:00
Zamil Majdy
20e410bd5a docs(copilot): add TODO for per-user rate limit support in config 2026-03-15 22:57:15 +07:00
Zamil Majdy
8e6d140580 fix(backend/copilot): document provider-dependent chunk.usage null behaviour
Not all OpenRouter-proxied providers honour stream_options={"include_usage":
True} — some always return chunk.usage=None. Add a comment at the guarded
check making this explicit and noting the tiktoken fallback activates in
that case (fail-open).
2026-03-15 22:39:03 +07:00
Zamil Majdy
d6cec14543 test(copilot): add unit tests for persist_and_record_usage in token_tracking
Adds 14 unit tests covering all branches of the shared token-usage
persistence helper: total_tokens semantics (prompt+completion, cache
excluded), session Usage record appending, multi-turn accumulation,
None session/user handling, rate-limit recording dispatch, fail-open
on Redis errors, and early-exit when tokens are zero.

Covers both baseline (no cache) and SDK (with cache breakdown) paths
so the ~100 lines of token_tracking.py have explicit test coverage.
2026-03-15 22:27:46 +07:00
Zamil Majdy
abd9dac288 fix(backend/copilot): mock get_user_credit_model instead of removed credit_db in tests 2026-03-15 13:47:41 +07:00
Zamil Majdy
6cc2c7a50d fix(frontend/copilot): remove dark mode classes, add 429 handling, use Next.js Link
- Remove all dark: Tailwind classes from copilot UsageLimits component
  and CoPilotUsageSection in credits page (no dark mode in copilot)
- Add 429 rate-limit detection in useCopilotStream onError with toast
  showing the reset time from the backend response
- Replace <a href> with Next.js <Link> in UsageLimits to avoid full
  page reload when navigating to credits page
2026-03-15 08:13:28 +07:00
Zamil Majdy
a0d534f24b fix(backend/copilot): address autogpt-reviewer should-fix items for rate limiting and token tracking
- Normalize total_tokens semantics: both baseline and SDK now compute
  total_tokens = prompt + completion, with cache fields kept separate
- Resolve credit model once in helpers.py execute_block to avoid
  duplicate get_user_credit_model() calls per block execution
- Add 429 test coverage: weekly rate limit and reset-time assertion tests
- Extract shared token tracking into token_tracking.py to DRY ~50 lines
  of duplicated usage persistence + rate-limit recording logic
- Add early return in check_rate_limit when both limits are 0 (unlimited)
  to skip unnecessary Redis round-trip
- Switch all chat route auth from Depends to Security for OpenAPI spec
  consistency
- Use transaction=True for Redis pipeline incrby+expire to ensure
  atomicity
2026-03-15 08:06:49 +07:00
Zamil Majdy
b9be577904 fix(backend/copilot): skip fallback token estimation for failed API requests
When a streaming error occurs before any output is produced (e.g.,
connection timeout), the fallback token estimation in the finally block
would incorrectly charge rate-limit tokens — including an artificial
minimum of 1 completion token via max(..., 1). This penalises users for
requests that produced no output.

Track whether an error occurred via _stream_error flag and skip the
fallback estimation when the request failed and no assistant text was
generated.
2026-03-14 22:46:12 +07:00
Zamil Majdy
b9951a3c53 fix: improve auth consistency, billing monitoring, and client-side navigation
- Use Security(auth.requires_user) on /usage endpoint for consistency
- Add BILLING_LEAK: prefix and TODO for atomic credit charging
- Cache credit_db() call to avoid redundant initialization
- Add test for unauthenticated /usage request returning 401
- Use router.push() instead of window.location.href for SPA navigation
- Align pytest.mark.asyncio loop_scope with existing test patterns
2026-03-14 22:26:39 +07:00
Zamil Majdy
6fb3b1e87b Merge branch 'origin/dev' into feat/tracking-cost-block
Resolved conflict in service.py: kept token usage tracking from this
branch and adopted new compact_result API from dev.
2026-03-14 21:14:03 +07:00
Zamil Majdy
760d4eeaf4 fix(frontend): add dark mode variants to UsageLimits progress bar colors
The progress bar fill colors (bg-orange-500 / bg-blue-500) were missing
dark: variants. All other colors in the component already had them.
2026-03-13 22:47:22 +07:00
Zamil Majdy
cb7d271472 refactor(backend/copilot): use credit_db() adapter pattern for credit operations
Replace inline _get_credits/_spend_credits helpers with the centralized
credit_db() accessor from db_accessors.py, consistent with workspace_db(),
chat_db(), and other DB accessors. Add module-level get_credits() and
spend_credits() to credit.py so the accessor can return the module directly.
2026-03-13 22:16:54 +07:00
Zamil Majdy
7ef530c672 Merge branch 'feat/tracking-cost-block' of github.com:Significant-Gravitas/AutoGPT into feat/tracking-cost-block 2026-03-13 22:06:00 +07:00
Zamil Majdy
699ecc8cec Merge branch 'dev' into feat/tracking-cost-block 2026-03-13 19:34:53 +07:00
Zamil Majdy
a662dc1680 fix(backend/store): bound split_camelcase input to prevent ReDoS
Add a 500-character input length limit to split_camelcase() to address
CodeQL polynomial regex warning on uncontrolled data.
2026-03-13 18:18:48 +07:00
Zamil Majdy
211be3aff1 revert: remove unnecessary TODO comment from config.py 2026-03-13 17:59:00 +07:00
Zamil Majdy
3120981e4b fix(backend/copilot): add TODO for per-user rate limit support in config 2026-03-13 17:57:42 +07:00
Zamil Majdy
73e1a5e76d fix(backend): use lazy %s formatting in logger.warning calls 2026-03-13 17:47:52 +07:00
Zamil Majdy
5966d3669d fix(platform): use round() for cache weights, accept string dates in UsageLimits 2026-03-13 17:38:25 +07:00
Zamil Majdy
d0f9ec55e4 refactor(backend): clean up block search code, add tests
- Extract _get_enabled_blocks() to deduplicate disabled/broken block
  filtering between get_missing_items() and get_stats()
- Reuse block instance from disabled check (no double instantiation)
- Move split_camelcase import to module level in hybrid_search.py
- Add tests: split_camelcase, tokenize CamelCase, disabled-block batch
  budget, get_stats with broken blocks
2026-03-13 17:30:54 +07:00
Zamil Majdy
c81ab1fc3b fix(backend/copilot): use adapter pattern for credit operations in executor
The CoPilot executor runs without a Prisma connection, so direct calls
to get_user_credit_model() caused "Client is not connected to the query
engine" errors. Replace with _get_credits/_spend_credits adapters that
fall back to RPC via DatabaseManagerAsyncClient when Prisma is unavailable.
Also add missing spend_credits/get_credits to DatabaseManagerAsyncClient.
2026-03-13 17:14:08 +07:00
Zamil Majdy
36606b206a fix(backend): add try/except to BlockHandler.get_stats() for broken blocks 2026-03-13 16:51:44 +07:00
Zamil Majdy
a8afaffd4c fix(frontend): format openapi.json with prettier 2026-03-13 16:43:57 +07:00
Zamil Majdy
a982fb8436 refactor(backend): extract shared split_camelcase helper, fix review comments
- Extract `split_camelcase()` into text_utils.py to avoid circular imports
  and share between block indexer and BM25 tokenizer (CodeRabbit nitpick)
- Fix docstring accuracy in tokenize() (CodeRabbit nitpick)
- Add exception handling for block_cls() in disabled-block filter
  to prevent a single broken block from crashing the backfill (Sentry)
- ReDoS not a concern: both regex patterns are linear-time (CodeQL)
2026-03-13 16:31:26 +07:00
Zamil Majdy
afcce75aff fix(backend): filter disabled blocks before batch_size in embedding backfill
The batch_size limit was applied before filtering out disabled blocks.
With 120+ disabled blocks and batch_size=100, the first 100 entries
were all disabled (skipped via continue), leaving the 36 enabled blocks
beyond the slice boundary never indexed. This made core blocks like
AITextGeneratorBlock, AIConversationBlock, etc. invisible to search.

Fix: filter disabled blocks from the missing list before slicing by
batch_size, so every batch slot goes to an enabled block that actually
needs indexing.
2026-03-13 16:31:26 +07:00
Zamil Majdy
70dfe64c6d fix(backend): split CamelCase block names for search indexing and BM25
Block names like "AITextGeneratorBlock" were indexed as single tokens,
making them invisible to PostgreSQL tsvector lexical search and BM25
reranking when users searched for "text generator" or "AI text". Now
CamelCase names are split into separate words before indexing
(e.g. "AI Text Generator Block") so both lexical and BM25 matching
work correctly.
2026-03-13 16:31:26 +07:00
Zamil Majdy
5446c7f18f fix(platform): consistent header icon sizing, halve rate limits
- Add size="icon" to SidebarTrigger for uniform h-9 w-9 button sizing
- Remove extra positioning wrappers (relative left-1, left-5) around header icons
- Halve daily token limit to 2.5M and weekly to 12.5M for more reasonable defaults
2026-03-13 15:55:18 +07:00
Zamil Majdy
2b0c9ba703 fix(frontend): use Button component for icon buttons, add timezone and <1% display
- UsageLimits and NotificationToggle now use <Button variant="ghost" size="icon">
  to match SidebarTrigger's padding/sizing
- Weekly reset time shows timezone abbreviation (e.g., "Mon 7:00 AM PST")
- Usage below 1% shows "<1% used" instead of "0% used" with a 1% min bar width
2026-03-13 15:42:05 +07:00
Zamil Majdy
195c7011ae fix(frontend): update usage after chat, show reset day, fix icon weight
- Invalidate usage query when stream completes so the usage bar
  updates immediately after chatting instead of waiting 30s.
- Show reset time as day/time in local timezone when over 24h away
  (e.g. "Mon 12:00 AM") instead of unclear "63h 41m".
- Use weight="light" on ChartBar icon to match other header icons.
2026-03-13 15:21:42 +07:00
Zamil Majdy
d4944fb22b fix(platform): emit StreamUsage as SSE comment, move usage to popover
StreamUsage events crashed the frontend because the Vercel AI SDK uses
z.strictObject() and rejects unknown event types. Fix by overriding
to_sse() to emit as an SSE comment (invisible to the parser). Usage
data is already recorded server-side (session DB + Redis counters).

Move usage limits from sidebar footer back to a ChartBar icon button
in the sidebar header that opens a popover on click.
2026-03-13 15:14:14 +07:00
Zamil Majdy
a5ed8fefa9 feat(copilot): cost-weighted token rate limiting with cache breakdown
- Rate limiter now uses Anthropic's cost model: cache_read at 10%,
  cache_creation at 25%, uncached and output at 100%
- Track cache_read_tokens and cache_creation_tokens separately in
  Usage model, StreamUsage response, and SDK token extraction
- Pass cache breakdown through to record_token_usage() for accurate
  weighted counting
- Add test for cost-weighted counting (10K cache_read → 1K weighted)

This makes multi-turn conversations fairer: cached system prompts
and tool schemas don't penalize users at full token cost.
2026-03-13 14:36:04 +07:00
Zamil Majdy
a52a777b29 fix(copilot): increase rate limits from 500K/5M to 5M/25M daily/weekly
A single CoPilot turn consumes ~10-15K tokens (system prompt + tool
schemas), so 500K daily only allowed ~35-50 turns which is too
restrictive for normal use.
2026-03-13 14:21:01 +07:00
Zamil Majdy
8bec7a6933 fix(frontend): move usage limits to left column on billing page
Move CoPilot Usage Limits below Automatic Refill Settings in the left
column and add "Open CoPilot" button for consistency with other sections.
2026-03-13 14:17:57 +07:00
Zamil Majdy
e73791efed fix(copilot): move usage limits to sidebar bottom, add to billing page
- Move UsageLimits from top-right headerSlot to left sidebar footer
- Show usage limits on both main copilot page and inside chat sessions
- Add CoPilot Usage Limits section to billing/credits page
- Change "Learn more about usage limits" to "Manage billing & credits"
- Remove unused headerSlot prop from ChatContainer/ChatMessagesContainer
- Clean up unused imports in CopilotPage
2026-03-13 14:08:57 +07:00
Zamil Majdy
2d161ce2b9 refactor(platform): replace per-session rate limits with daily fixed-window
Per-session limits were gameable (create new session to reset) and had
confusing reset semantics (12h inactivity TTL that refreshed on use).
Replace with daily fixed-window counter that resets at midnight UTC,
matching the weekly window pattern.

- session_token_limit → daily_token_limit (500K tokens/day)
- Redis key: copilot:usage:daily:{user_id}:{YYYY-MM-DD} with TTL
- Remove session_id from usage API endpoint and record_token_usage()
- Frontend: "Current session" → "Today", "Weekly limits" → "This week"
2026-03-13 13:22:59 +07:00
Zamil Majdy
6fc4989654 fix(copilot): include full context window in fallback token estimation
Each API call sends the complete openai_messages list (system prompt +
history + turn), so the fallback estimator should count all of it to
match what prompt_tokens would have reported.
2026-03-13 12:36:51 +07:00
Zamil Majdy
976443bf6e refactor(copilot): use tiktoken for fallback token estimation
Replace rough chars/4 heuristic with proper tiktoken tokenizer via
estimate_token_count/estimate_token_count_str from backend.util.prompt.
2026-03-13 05:24:53 +07:00
Zamil Majdy
4ceb15b3f1 fix(copilot): scope fallback token estimation to current turn only
The fallback estimator was counting the entire openai_messages history
(system prompt + all previous turns) instead of just the messages added
during the current turn. This caused overcounting and overly strict
rate limiting when providers don't return streaming usage data.
2026-03-13 03:44:30 +07:00
Zamil Majdy
3096f94996 feat(copilot): set default rate limits based on observed usage data
Session: 500K tokens (P90 session usage ~550K)
Weekly: 5M tokens (~10x heaviest observed weekly user)
2026-03-13 03:34:05 +07:00
Zamil Majdy
6f90729612 merge: resolve conflicts with dev (coerce_inputs_to_schema)
Merge dev's coerce_inputs_to_schema into execute_block alongside
credit charging. Both features coexist: coercion runs before
credit check. Test file combines both test suites.
2026-03-13 03:15:37 +07:00
Zamil Majdy
ebf89dde8b fix(copilot): include cached tokens in SDK token tracking
Anthropic's API reports cached tokens separately (cache_read_input_tokens,
cache_creation_input_tokens) from input_tokens. The previous code only read
input_tokens, undercounting total tokens for rate limiting.

OpenRouter (baseline path) already includes cached tokens in prompt_tokens
per OpenAI-compatible format — added clarifying comment.
2026-03-13 02:46:57 +07:00
Zamil Majdy
5d057e97e5 fix(copilot): move StreamUsage/StreamFinish back out of finally block
PEP 525 prohibits yielding from finally in async generators during
aclose() — doing so raises RuntimeError on client disconnect. Move
yields after try/finally where they work on normal completion and are
harmlessly unreachable on GeneratorExit.
2026-03-13 00:40:58 +07:00
Zamil Majdy
1d2f641a26 fix(copilot): move session.usage.append to finally block in SDK service
Ensures session usage persistence and rate-limit recording stay
consistent even when an exception interrupts the try block. Mirrors
the baseline service pattern.
2026-03-13 00:35:09 +07:00
Zamil Majdy
dcb71ab0b9 test(platform): add tests for usage endpoint and UsageLimits component
- 3 new tests for GET /usage endpoint (with/without session_id, config limits)
- 9 new tests for UsageLimits frontend component (loading, empty, rendering,
  percentage capping, session-only/weekly-only, link, hook args)
2026-03-13 00:30:52 +07:00
Zamil Majdy
8136b90860 fix(copilot): move StreamUsage/StreamFinish yields into finally block
On client disconnect, GeneratorExit terminates the generator after the
finally block, making yields after it unreachable. Moving them inside
finally ensures they are at least attempted.
2026-03-13 00:05:41 +07:00
Zamil Majdy
4d179a7c37 Merge branch 'dev' into feat/tracking-cost-block 2026-03-12 23:56:38 +07:00
Zamil Majdy
f78adcdc65 fix(platform): address all open review items — clock skew, token recording, generated hooks
- Fix clock skew: share single `now` timestamp across `_weekly_key` and
  `_weekly_reset_time` calls to prevent ISO week boundary race condition
- SDK: move `record_token_usage` to finally block so tokens are always
  recorded even when exceptions interrupt the stream (prevents rate limit bypass)
- Baseline: wrap `record_token_usage` in try/except so it cannot block
  session persistence
- Routes: pass `session_id=None` instead of empty string to avoid
  malformed Redis keys; `get_usage_status` now skips session lookup when
  session_id is None
- Tests: add partial None counter tests and no-session-id test
- Frontend: replace raw fetch with generated Orval hook
  (`useGetV2GetCopilotUsage`), use generated `CoPilotUsageStatus` type,
  fix `Date` vs `string` type for `resets_at`
2026-03-12 23:18:31 +07:00
Zamil Majdy
40388b7520 fix(platform): address review — orphan keys, TTL gather, dark mode, lazy logging
- Guard record_token_usage with user_id check to prevent orphan Redis keys
- Fold session TTL lookup into asyncio.gather (eliminate serial round-trip)
- Add dark: variants to UsageLimits component
- Use lazy %s formatting in logger calls instead of eager f-strings
2026-03-12 23:03:36 +07:00
Zamil Majdy
dd7be1158b test(backend): expand rate limit test coverage, fix token estimation null safety
- Add tests for: past resets_at (negative time), expired TTL fallback,
  _session_reset_from_ttl Redis error, pipeline TTL assertions,
  pipeline execute RedisError
- Fix null safety in baseline token estimation (m.get("content") or "")
- Use MagicMock for pipeline sync methods to eliminate coroutine warnings
2026-03-12 22:38:31 +07:00
Zamil Majdy
c0e59f0a6b fix(backend): address pushback on credit charging and token estimation
- Post-exec InsufficientBalanceError: return output + log warning
  instead of ErrorResponse (block already executed with side effects)
- Add token estimation fallback for providers that don't support
  stream_options include_usage (1 token ≈ 4 chars)
- Remove unnecessary # type: ignore on AsyncRedis param
2026-03-12 22:25:14 +07:00
Zamil Majdy
104d1f1bf4 fix(backend): revert to += for prompt tokens, add negative time guard
- Revert prompt token tracking back to += (sum all rounds): each API
  call bills independently, so total billed tokens is the correct
  metric for rate limiting
- Guard against negative reset time in RateLimitExceeded message
  (clock skew / stale Redis TTL)
2026-03-12 22:09:05 +07:00
Zamil Majdy
d9e9cd4c98 fix(backend): address human + Sentry review feedback on CoPilot tracking
- Fix double-counting of prompt tokens in multi-round tool calls:
  use = (not +=) since each API call's prompt_tokens includes full
  conversation history (both baseline and SDK paths)
- Fix docstring: "sliding-window" → "fixed-window" counters
- Type redis param as AsyncRedis instead of object
- Use pipeline(transaction=False) for independent INCRBY+EXPIRE
- Simplify days_until_monday calculation with `or 7`
- Flatten time_str into ternary, merge f-strings
- Parallelize Redis gets with asyncio.gather
- Document session_id=None behavior in usage endpoint
2026-03-12 21:54:31 +07:00
Zamil Majdy
ca416300ec fix(backend): address second round CodeRabbit review feedback
- Narrow except blocks to (RedisError, ConnectionError, OSError) instead
  of bare Exception to avoid hiding coding bugs
- Remove raw user_id from log messages to prevent PII leaks under Redis
  outages
- Reuse credit_model instead of fetching twice in execute_block()
- Treat post-execution InsufficientBalanceError as fatal (matches
  executor behavior) instead of silently swallowing it
2026-03-12 21:37:32 +07:00
Zamil Majdy
c589cd0c43 fix(backend): address CodeRabbit review feedback
- Derive session reset time from Redis TTL instead of hardcoded 3h
- Add description to UsageWindow.limit documenting 0 = unlimited
- Compare balance < cost instead of balance <= 0 in pre-exec check
- Document TOCTOU behavior in check_rate_limit docstring
2026-03-12 21:10:52 +07:00
Zamil Majdy
b6d863fcd2 feat(platform): add CoPilot credit charging, token tracking, and rate limiting
- Charge credits for block execution in CoPilot (matching graph executor behavior)
- Track LLM token usage for both SDK (Claude) and baseline (OpenAI) paths
- Add Redis-based per-user rate limiting with session and weekly token windows
- Expose usage status via GET /api/chat/usage endpoint
- Add frontend UsageLimits component with progress bars in CoPilot header
- Include unit tests for rate limiting and block credit charging
2026-03-12 20:50:21 +07:00
44 changed files with 2095 additions and 4986 deletions

View File

@@ -178,16 +178,6 @@ yield "image_url", result_url
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware

View File

@@ -8,7 +8,7 @@ from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
@@ -27,6 +27,12 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -207,7 +215,7 @@ async def list_sessions(
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; " "defaulting to empty"
"Failed to fetch processing status from Redis; defaulting to empty"
)
return ListSessionsResponse(
@@ -229,7 +237,7 @@ async def list_sessions(
"/sessions",
)
async def create_session(
user_id: Annotated[str, Depends(auth.get_user_id)],
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -348,7 +356,7 @@ async def update_session_title_route(
)
async def get_session(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
@@ -389,6 +397,10 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -396,6 +408,25 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get(
"/usage",
)
async def get_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -405,7 +436,7 @@ async def get_session(
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
@@ -450,7 +481,7 @@ async def cancel_session_task(
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
user_id: str = Security(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
@@ -467,7 +498,7 @@ async def stream_chat_post(
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
user_id: Authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
@@ -476,9 +507,7 @@ async def stream_chat_post(
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
@@ -496,6 +525,18 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
if user_id:
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -730,7 +771,7 @@ async def stream_chat_post(
)
async def resume_session_stream(
session_id: str,
user_id: str | None = Depends(auth.get_user_id),
user_id: str = Security(auth.get_user_id),
):
"""
Resume an active stream for a session.

View File

@@ -1,5 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -251,6 +252,156 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
# Ensure the rate-limit branch is entered by setting a non-zero limit.
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
assert "daily" in response.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
resets_at = datetime.now(UTC) + timedelta(days=3)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("weekly", resets_at),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"].lower()
assert "weekly" in detail
assert "resets in" in detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
"""The 429 response detail should include the human-readable reset time."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded(
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"]
assert "2h" in detail
assert "Resets in" in detail
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
def test_usage_rejects_unauthenticated_request() -> None:
"""GET /usage should return 401 when no valid JWT is provided."""
unauthenticated_app = fastapi.FastAPI()
unauthenticated_app.include_router(chat_routes.router)
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────

View File

@@ -36,6 +36,7 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -43,6 +44,7 @@ from backend.copilot.service import (
_get_openai_client,
config,
)
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
@@ -221,6 +223,10 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -232,6 +238,7 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -242,6 +249,20 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -394,6 +415,7 @@ async def stream_chat_completion_baseline(
)
except Exception as e:
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
@@ -411,6 +433,49 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
# turns consume prompt tokens on each API call, so the total is underestimated.
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
# cannot break out cache_read/cache_creation weights. Users on the
# baseline path may be slightly over-counted vs the SDK path.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -421,4 +486,16 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,6 +70,27 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Per-turn token cost varies with context size: ~10-15K for early turns,
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,

View File

@@ -73,6 +73,9 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):
@@ -98,7 +101,10 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Calculate usage from token counts
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
# completion totals. This is a known limitation.
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(

View File

@@ -0,0 +1,266 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_USAGE_KEY_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
daily_used = 0
weekly_used = 0
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for usage status, returning zeros")
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
# round-trip entirely.
if daily_token_limit <= 0 and weekly_token_limit <= 0:
return
now = datetime.now(UTC)
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for rate limit check, allowing request")
return
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
# transaction=False: these are independent INCRBY+EXPIRE pairs on
# separate keys — no cross-key atomicity needed. Skipping
# MULTI/EXEC avoids the overhead. If the connection drops between
# INCRBY and EXPIRE the key survives until the next date-based key
# rotation (daily/weekly), so the memory-leak risk is negligible.
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_USAGE_KEY_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_USAGE_KEY_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC)."""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)

View File

@@ -0,0 +1,334 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,12 +186,43 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
prompt_tokens: int = Field(
...,
serialization_alias="promptTokens",
description="Number of uncached prompt tokens",
)
completion_tokens: int = Field(
...,
serialization_alias="completionTokens",
description="Number of completion tokens",
)
total_tokens: int = Field(
...,
serialization_alias="totalTokens",
description="Total number of tokens (raw, not weighted)",
)
cache_read_tokens: int = Field(
default=0,
serialization_alias="cacheReadTokens",
description="Prompt tokens served from cache (10% cost)",
)
cache_creation_tokens: int = Field(
default=0,
serialization_alias="cacheCreationTokens",
description="Prompt tokens written to cache (25% cost)",
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True, by_alias=True)}\n\n"
class StreamError(StreamBaseResponse):

View File

@@ -55,12 +55,14 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
_generate_session_title,
_is_langfuse_configured,
)
from ..token_tracking import persist_and_record_usage
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
@@ -736,6 +738,13 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -1112,7 +1121,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
# Log ResultMessage details and capture token usage
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1131,6 +1140,33 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with the
# CLI's active context so they stay identical.
@@ -1347,6 +1383,26 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
# total_tokens = prompt (uncached input) + completion (output).
# Cache tokens are tracked separately and excluded from total
# so that the semantics match the baseline path (OpenRouter)
# which folds cache into prompt_tokens. Keeping total_tokens
# = prompt + completion everywhere makes cross-path comparisons
# and session-level aggregation consistent.
total_tokens = turn_prompt_tokens + turn_completion_tokens
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1411,6 +1467,20 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).

View File

@@ -0,0 +1,93 @@
"""Shared token-usage persistence and rate-limit recording.
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
This module extracts that common logic so both paths stay in sync.
"""
import logging
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
async def persist_and_record_usage(
*,
session: ChatSession | None,
user_id: str | None,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
log_prefix: str = "",
cost_usd: float | str | None = None,
) -> int:
"""Persist token usage to session and record for rate limiting.
Args:
session: The chat session to append usage to (may be None on error).
user_id: User ID for rate-limit counters (skipped if None).
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
Returns:
The computed total_tokens (prompt + completion; cache excluded).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
if prompt_tokens <= 0 and completion_tokens <= 0:
return 0
# total_tokens = prompt + completion. Cache tokens are tracked
# separately and excluded from total so both baseline and SDK
# paths share the same semantics.
total_tokens = prompt_tokens + completion_tokens
if session is not None:
session.usage.append(
Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
)
if cache_read_tokens or cache_creation_tokens:
logger.info(
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
)
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
f"completion={completion_tokens}, total={total_tokens}"
)
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
return total_tokens

View File

@@ -0,0 +1,281 @@
"""Unit tests for token_tracking.persist_and_record_usage.
Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from .model import ChatSession, Usage
from .token_tracking import persist_and_record_usage
def _make_session() -> ChatSession:
"""Return a minimal in-memory ChatSession for testing."""
return ChatSession(
session_id="sess-test",
user_id="user-test",
title=None,
messages=[],
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
# ---------------------------------------------------------------------------
# Return value / total_tokens semantics
# ---------------------------------------------------------------------------
class TestTotalTokens:
@pytest.mark.asyncio
async def test_returns_prompt_plus_completion(self):
"""total_tokens = prompt + completion (cache excluded from total)."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=300,
completion_tokens=200,
)
assert total == 500
@pytest.mark.asyncio
async def test_returns_zero_when_no_tokens(self):
"""Returns 0 early when both prompt and completion are zero."""
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=0,
completion_tokens=0,
)
assert total == 0
@pytest.mark.asyncio
async def test_cache_tokens_excluded_from_total(self):
"""Cache tokens are stored separately and not added to total_tokens."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=5000,
cache_creation_tokens=200,
)
# total = prompt + completion only (5000 + 200 cache excluded)
assert total == 150
@pytest.mark.asyncio
async def test_baseline_path_no_cache(self):
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id="u1",
prompt_tokens=1000,
completion_tokens=400,
log_prefix="[Baseline]",
)
assert total == 1400
@pytest.mark.asyncio
async def test_sdk_path_with_cache(self):
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id="u2",
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=8000,
cache_creation_tokens=400,
log_prefix="[SDK]",
cost_usd=0.0015,
)
assert total == 300
# ---------------------------------------------------------------------------
# Session persistence
# ---------------------------------------------------------------------------
class TestSessionPersistence:
@pytest.mark.asyncio
async def test_appends_usage_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
assert len(session.usage) == 1
usage: Usage = session.usage[0]
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
assert usage.cache_read_tokens == 0
assert usage.cache_creation_tokens == 0
@pytest.mark.asyncio
async def test_appends_cache_breakdown_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=200,
completion_tokens=80,
cache_read_tokens=3000,
cache_creation_tokens=500,
)
usage: Usage = session.usage[0]
assert usage.cache_read_tokens == 3000
assert usage.cache_creation_tokens == 500
@pytest.mark.asyncio
async def test_multiple_turns_append_multiple_records(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session, user_id=None, prompt_tokens=100, completion_tokens=50
)
await persist_and_record_usage(
session=session, user_id=None, prompt_tokens=200, completion_tokens=70
)
assert len(session.usage) == 2
@pytest.mark.asyncio
async def test_none_session_does_not_raise(self):
"""When session is None (e.g. error path), no exception should be raised."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
assert total == 150
@pytest.mark.asyncio
async def test_no_append_when_zero_tokens(self):
"""When tokens are zero, function returns early — session unchanged."""
session = _make_session()
total = await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=0,
completion_tokens=0,
)
assert total == 0
assert len(session.usage) == 0
# ---------------------------------------------------------------------------
# Rate-limit recording
# ---------------------------------------------------------------------------
class TestRateLimitRecording:
@pytest.mark.asyncio
async def test_calls_record_token_usage_when_user_id_present(self):
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
)
mock_record.assert_awaited_once_with(
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
)
@pytest.mark.asyncio
async def test_skips_record_when_user_id_is_none(self):
"""Anonymous sessions should not create Redis keys."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
mock_record.assert_not_awaited()
@pytest.mark.asyncio
async def test_record_failure_does_not_raise(self):
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
# Should not raise
total = await persist_and_record_usage(
session=None,
user_id="user-xyz",
prompt_tokens=100,
completion_tokens=50,
)
assert total == 150
@pytest.mark.asyncio
async def test_skips_record_when_zero_tokens(self):
"""Returns 0 before calling record_token_usage when tokens are zero."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id="user-abc",
prompt_tokens=0,
completion_tokens=0,
)
mock_record.assert_not_awaited()

View File

@@ -8,11 +8,13 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data.db_accessors import workspace_db
from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import credit_db, workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
@@ -115,6 +117,21 @@ async def execute_block(
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check (courtesy; spend_credits is atomic)
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
_credit_db = credit_db()
if has_cost:
balance = await _credit_db.get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -123,6 +140,51 @@ async def execute_block(
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _credit_db.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except Exception as e:
# Block already executed (with possible side effects). Never
# return ErrorResponse here — the user received output and
# deserves it. Log the billing failure for reconciliation.
leak_type = (
"INSUFFICIENT_BALANCE"
if isinstance(e, InsufficientBalanceError)
else "UNEXPECTED_ERROR"
)
logger.error(
"BILLING_LEAK[%s]: block executed but credit charge failed — "
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
leak_type,
user_id,
block_id,
node_exec_id,
cost,
e,
extra={
"json_fields": {
"billing_leak": True,
"leak_type": leak_type,
"user_id": user_id,
"cost": str(cost),
}
},
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -133,14 +195,14 @@ async def execute_block(
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
logger.warning("Block execution failed: %s", e)
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),

View File

@@ -1,18 +1,197 @@
"""Tests for execute_block type coercion in helpers.py.
Verifies that execute_block() coerces string input values to match the block's
expected input types, mirroring the executor's validate_exec() logic.
This is critical for @@agptfile: expansion, where file content is always a string
but the block may expect structured types (e.g. list[list[str]]).
"""
"""Tests for execute_block — credit charging and type coercion."""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
def _patch_credit_db(
get_credits_return: int = 100,
spend_credits_side_effect: Any = None,
):
"""Patch credit_db accessor to return a mock credit adapter."""
mock_credit = MagicMock()
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
if spend_credits_side_effect is not None:
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
else:
mock_credit.spend_credits = AsyncMock()
return (
patch(
"backend.copilot.tools.helpers.credit_db",
return_value=mock_credit,
),
mock_credit,
)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_credit.spend_credits.assert_awaited_once()
call_kwargs = mock_credit.spend_credits.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_credit.get_credits.assert_not_awaited()
mock_credit.spend_credits.assert_not_awaited()
async def test_returns_output_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, output is still returned (block already ran)."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(
get_credits_return=15,
spend_credits_side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
),
)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
# Block already executed (with side effects), so output is returned
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
@@ -28,7 +207,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
return schema
def _make_block(
def _make_coerce_block(
block_id: str,
name: str,
annotations: dict[str, Any],
@@ -60,7 +239,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_block(
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
@@ -103,7 +282,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_block(
block = _make_coerce_block(
"list-block",
"List Block",
{"items": list[str]},
@@ -135,7 +314,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_block(
block = _make_coerce_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
@@ -167,7 +346,7 @@ async def test_coerce_json_string_to_dict():
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_block(
block = _make_coerce_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
@@ -201,7 +380,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_block(
block = _make_coerce_block(
"int-block",
"Int Block",
{"count": int},
@@ -234,7 +413,7 @@ async def test_coerce_string_to_int():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_block(
block = _make_coerce_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
@@ -267,7 +446,7 @@ async def test_coerce_skips_none_values():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_block(
block = _make_coerce_block(
"union-block",
"Union Block",
{"content": str | list[str]},
@@ -301,7 +480,7 @@ async def test_coerce_union_type_preserves_valid_member():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_block(
block = _make_coerce_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},

View File

@@ -129,3 +129,16 @@ def review_db():
review_db = get_database_manager_async_client()
return review_db
def credit_db():
if db.is_connected():
from backend.data import db_manager as _credit_db
credit_db = _credit_db
else:
from backend.util.clients import get_database_manager_async_client
credit_db = get_database_manager_async_client()
return credit_db

View File

@@ -148,6 +148,11 @@ async def _get_credits(user_id: str) -> int:
return await user_credit_model.get_credits(user_id)
# Public aliases used by db_accessors.credit_db() when Prisma is connected
get_credits = _get_credits
spend_credits = _spend_credits
class DatabaseManager(AppService):
"""Database connection pooling service.
@@ -512,6 +517,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -183,8 +183,7 @@ class WorkspaceManager:
f"{Config().max_file_size_mb}MB limit"
)
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
# Virus scan content before persisting (defense in depth)
await scan_content_safe(content, filename=filename)
# Determine path with session scoping

View File

@@ -1,440 +0,0 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { screen, cleanup } from "@testing-library/react";
import { render } from "@/tests/integrations/test-utils";
import React from "react";
import { BlockUIType } from "../components/types";
import type {
CustomNodeData,
CustomNode as CustomNodeType,
} from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { NodeProps } from "@xyflow/react";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
// ---- Mock sub-components ----
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContainer",
() => ({
NodeContainer: ({
children,
hasErrors,
}: {
children: React.ReactNode;
hasErrors: boolean;
}) => (
<div data-testid="node-container" data-has-errors={String(!!hasErrors)}>
{children}
</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader",
() => ({
NodeHeader: ({ data }: { data: CustomNodeData }) => (
<div data-testid="node-header">{data.title}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/StickyNoteBlock",
() => ({
StickyNoteBlock: ({ data }: { data: CustomNodeData }) => (
<div data-testid="sticky-note-block">{data.title}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeAdvancedToggle",
() => ({
NodeAdvancedToggle: () => <div data-testid="node-advanced-toggle" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput",
() => ({
NodeDataRenderer: () => <div data-testid="node-data-renderer" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeExecutionBadge",
() => ({
NodeExecutionBadge: () => <div data-testid="node-execution-badge" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu",
() => ({
NodeRightClickMenu: ({ children }: { children: React.ReactNode }) => (
<div data-testid="node-right-click-menu">{children}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer",
() => ({
WebhookDisclaimer: () => <div data-testid="webhook-disclaimer" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/SubAgentUpdateFeature",
() => ({
SubAgentUpdateFeature: () => <div data-testid="sub-agent-update" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/AyrshareConnectButton",
() => ({
AyrshareConnectButton: () => <div data-testid="ayrshare-connect-button" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/FormCreator",
() => ({
FormCreator: () => <div data-testid="form-creator" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/OutputHandler",
() => ({
OutputHandler: () => <div data-testid="output-handler" />,
}),
);
vi.mock(
"@/components/renderers/InputRenderer/utils/input-schema-pre-processor",
() => ({
preprocessInputSchema: (schema: unknown) => schema,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode",
() => ({
useCustomNode: ({ data }: { data: CustomNodeData }) => ({
inputSchema: data.inputSchema,
outputSchema: data.outputSchema,
isMCPWithTool: false,
}),
}),
);
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: () => ({
getNodes: () => [],
getEdges: () => [],
setNodes: vi.fn(),
setEdges: vi.fn(),
getNode: vi.fn(),
}),
useNodeId: () => "test-node-id",
useUpdateNodeInternals: () => vi.fn(),
Handle: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
Position: { Left: "left", Right: "right", Top: "top", Bottom: "bottom" },
};
});
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
// ---- Helpers ----
function buildNodeData(
overrides: Partial<CustomNodeData> = {},
): CustomNodeData {
return {
hardcodedValues: {},
title: "Test Block",
description: "A test block",
inputSchema: { type: "object", properties: {} },
outputSchema: { type: "object", properties: {} },
uiType: BlockUIType.STANDARD,
block_id: "block-123",
costs: [],
categories: [],
...overrides,
};
}
function buildNodeProps(
dataOverrides: Partial<CustomNodeData> = {},
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
): NodeProps<CustomNodeType> {
return {
id: "node-1",
data: buildNodeData(dataOverrides),
selected: false,
type: "custom",
isConnectable: true,
positionAbsoluteX: 0,
positionAbsoluteY: 0,
zIndex: 0,
dragging: false,
dragHandle: undefined,
draggable: true,
selectable: true,
deletable: true,
parentId: undefined,
width: undefined,
height: undefined,
sourcePosition: undefined,
targetPosition: undefined,
...propsOverrides,
};
}
function renderCustomNode(
dataOverrides: Partial<CustomNodeData> = {},
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
) {
const props = buildNodeProps(dataOverrides, propsOverrides);
return render(<CustomNode {...props} />);
}
function createExecutionResult(
overrides: Partial<NodeExecutionResult> = {},
): NodeExecutionResult {
return {
node_exec_id: overrides.node_exec_id ?? "exec-1",
node_id: overrides.node_id ?? "node-1",
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
graph_id: overrides.graph_id ?? "graph-1",
graph_version: overrides.graph_version ?? 1,
user_id: overrides.user_id ?? "test-user",
block_id: overrides.block_id ?? "block-1",
status: overrides.status ?? "COMPLETED",
input_data: overrides.input_data ?? {},
output_data: overrides.output_data ?? {},
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
};
}
// ---- Tests ----
beforeEach(() => {
cleanup();
});
describe("CustomNode", () => {
describe("STANDARD type rendering", () => {
it("renders NodeHeader with the block title", () => {
renderCustomNode({ title: "My Standard Block" });
const header = screen.getByTestId("node-header");
expect(header).toBeDefined();
expect(header.textContent).toContain("My Standard Block");
});
it("renders NodeContainer, FormCreator, OutputHandler, and NodeExecutionBadge", () => {
renderCustomNode();
expect(screen.getByTestId("node-container")).toBeDefined();
expect(screen.getByTestId("form-creator")).toBeDefined();
expect(screen.getByTestId("output-handler")).toBeDefined();
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
expect(screen.getByTestId("node-data-renderer")).toBeDefined();
expect(screen.getByTestId("node-advanced-toggle")).toBeDefined();
});
it("wraps content in NodeRightClickMenu", () => {
renderCustomNode();
expect(screen.getByTestId("node-right-click-menu")).toBeDefined();
});
it("does not render StickyNoteBlock for STANDARD type", () => {
renderCustomNode();
expect(screen.queryByTestId("sticky-note-block")).toBeNull();
});
});
describe("NOTE type rendering", () => {
it("renders StickyNoteBlock instead of main UI", () => {
renderCustomNode({ uiType: BlockUIType.NOTE, title: "My Note" });
const note = screen.getByTestId("sticky-note-block");
expect(note).toBeDefined();
expect(note.textContent).toContain("My Note");
});
it("does not render NodeContainer or other standard components", () => {
renderCustomNode({ uiType: BlockUIType.NOTE });
expect(screen.queryByTestId("node-container")).toBeNull();
expect(screen.queryByTestId("node-header")).toBeNull();
expect(screen.queryByTestId("form-creator")).toBeNull();
expect(screen.queryByTestId("output-handler")).toBeNull();
});
});
describe("WEBHOOK type rendering", () => {
it("renders WebhookDisclaimer for WEBHOOK type", () => {
renderCustomNode({ uiType: BlockUIType.WEBHOOK });
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
});
it("renders WebhookDisclaimer for WEBHOOK_MANUAL type", () => {
renderCustomNode({ uiType: BlockUIType.WEBHOOK_MANUAL });
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
});
});
describe("AGENT type rendering", () => {
it("renders SubAgentUpdateFeature for AGENT type", () => {
renderCustomNode({ uiType: BlockUIType.AGENT });
expect(screen.getByTestId("sub-agent-update")).toBeDefined();
});
it("does not render SubAgentUpdateFeature for non-AGENT types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.queryByTestId("sub-agent-update")).toBeNull();
});
});
describe("OUTPUT type rendering", () => {
it("does not render OutputHandler for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.queryByTestId("output-handler")).toBeNull();
});
it("still renders FormCreator and other components for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.getByTestId("form-creator")).toBeDefined();
expect(screen.getByTestId("node-header")).toBeDefined();
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
});
describe("AYRSHARE type rendering", () => {
it("renders AyrshareConnectButton for AYRSHARE type", () => {
renderCustomNode({ uiType: BlockUIType.AYRSHARE });
expect(screen.getByTestId("ayrshare-connect-button")).toBeDefined();
});
it("does not render AyrshareConnectButton for non-AYRSHARE types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.queryByTestId("ayrshare-connect-button")).toBeNull();
});
});
describe("error states", () => {
it("sets hasErrors on NodeContainer when data.errors has non-empty values", () => {
renderCustomNode({
errors: { field1: "This field is required" },
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("true");
});
it("does not set hasErrors when data.errors is empty", () => {
renderCustomNode({ errors: {} });
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("does not set hasErrors when data.errors values are all empty strings", () => {
renderCustomNode({ errors: { field1: "" } });
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("sets hasErrors when last execution result has error in output_data", () => {
renderCustomNode({
nodeExecutionResults: [
createExecutionResult({
output_data: { error: ["Something went wrong"] },
}),
],
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("true");
});
it("does not set hasErrors when execution results have no error", () => {
renderCustomNode({
nodeExecutionResults: [
createExecutionResult({
output_data: { result: ["success"] },
}),
],
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
});
describe("NodeExecutionBadge", () => {
it("always renders NodeExecutionBadge for non-NOTE types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
it("renders NodeExecutionBadge for AGENT type", () => {
renderCustomNode({ uiType: BlockUIType.AGENT });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
it("renders NodeExecutionBadge for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
});
describe("edge cases", () => {
it("renders without nodeExecutionResults", () => {
renderCustomNode({ nodeExecutionResults: undefined });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("renders without errors property", () => {
renderCustomNode({ errors: undefined });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("renders with empty execution results array", () => {
renderCustomNode({ nodeExecutionResults: [] });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
});
});

View File

@@ -1,342 +0,0 @@
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { useBlockMenuStore } from "../stores/blockMenuStore";
import { useControlPanelStore } from "../stores/controlPanelStore";
import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types";
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
// ---------------------------------------------------------------------------
// Mocks for heavy child components
// ---------------------------------------------------------------------------
vi.mock(
"../components/NewControlPanel/NewBlockMenu/BlockMenuDefault/BlockMenuDefault",
() => ({
BlockMenuDefault: () => (
<div data-testid="block-menu-default">Default Content</div>
),
}),
);
vi.mock(
"../components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch",
() => ({
BlockMenuSearch: () => (
<div data-testid="block-menu-search">Search Results</div>
),
}),
);
// Mock query client used by the search bar hook
vi.mock("@/lib/react-query/queryClient", () => ({
getQueryClient: () => ({
invalidateQueries: vi.fn(),
}),
}));
// ---------------------------------------------------------------------------
// Reset stores before each test
// ---------------------------------------------------------------------------
afterEach(() => {
cleanup();
});
beforeEach(() => {
useBlockMenuStore.getState().reset();
useBlockMenuStore.setState({
filters: [],
creators: [],
creators_list: [],
categoryCounts: {
blocks: 0,
integrations: 0,
marketplace_agents: 0,
my_agents: 0,
},
});
useControlPanelStore.getState().reset();
});
// ===========================================================================
// Section 1: blockMenuStore unit tests
// ===========================================================================
describe("blockMenuStore", () => {
describe("searchQuery", () => {
it("defaults to an empty string", () => {
expect(useBlockMenuStore.getState().searchQuery).toBe("");
});
it("sets the search query", () => {
useBlockMenuStore.getState().setSearchQuery("timer");
expect(useBlockMenuStore.getState().searchQuery).toBe("timer");
});
});
describe("defaultState", () => {
it("defaults to SUGGESTION", () => {
expect(useBlockMenuStore.getState().defaultState).toBe(
DefaultStateType.SUGGESTION,
);
});
it("sets the default state", () => {
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
expect(useBlockMenuStore.getState().defaultState).toBe(
DefaultStateType.ALL_BLOCKS,
);
});
});
describe("filters", () => {
it("defaults to an empty array", () => {
expect(useBlockMenuStore.getState().filters).toEqual([]);
});
it("adds a filter", () => {
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.blocks,
]);
});
it("removes a filter", () => {
useBlockMenuStore
.getState()
.setFilters([
SearchEntryFilterAnyOfItem.blocks,
SearchEntryFilterAnyOfItem.integrations,
]);
useBlockMenuStore
.getState()
.removeFilter(SearchEntryFilterAnyOfItem.blocks);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.integrations,
]);
});
it("replaces all filters with setFilters", () => {
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
useBlockMenuStore
.getState()
.setFilters([SearchEntryFilterAnyOfItem.marketplace_agents]);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.marketplace_agents,
]);
});
});
describe("creators", () => {
it("adds a creator", () => {
useBlockMenuStore.getState().addCreator("alice");
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
});
it("removes a creator", () => {
useBlockMenuStore.getState().setCreators(["alice", "bob"]);
useBlockMenuStore.getState().removeCreator("alice");
expect(useBlockMenuStore.getState().creators).toEqual(["bob"]);
});
it("replaces all creators with setCreators", () => {
useBlockMenuStore.getState().addCreator("alice");
useBlockMenuStore.getState().setCreators(["charlie"]);
expect(useBlockMenuStore.getState().creators).toEqual(["charlie"]);
});
});
describe("categoryCounts", () => {
it("sets category counts", () => {
const counts = {
blocks: 10,
integrations: 5,
marketplace_agents: 3,
my_agents: 2,
};
useBlockMenuStore.getState().setCategoryCounts(counts);
expect(useBlockMenuStore.getState().categoryCounts).toEqual(counts);
});
});
describe("searchId", () => {
it("defaults to undefined", () => {
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
});
it("sets and clears searchId", () => {
useBlockMenuStore.getState().setSearchId("search-123");
expect(useBlockMenuStore.getState().searchId).toBe("search-123");
useBlockMenuStore.getState().setSearchId(undefined);
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
});
});
describe("integration", () => {
it("defaults to undefined", () => {
expect(useBlockMenuStore.getState().integration).toBeUndefined();
});
it("sets the integration", () => {
useBlockMenuStore.getState().setIntegration("slack");
expect(useBlockMenuStore.getState().integration).toBe("slack");
});
});
describe("reset", () => {
it("resets searchQuery, searchId, defaultState, and integration", () => {
useBlockMenuStore.getState().setSearchQuery("hello");
useBlockMenuStore.getState().setSearchId("id-1");
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
useBlockMenuStore.getState().setIntegration("github");
useBlockMenuStore.getState().reset();
const state = useBlockMenuStore.getState();
expect(state.searchQuery).toBe("");
expect(state.searchId).toBeUndefined();
expect(state.defaultState).toBe(DefaultStateType.SUGGESTION);
expect(state.integration).toBeUndefined();
});
it("does not reset filters or creators (by design)", () => {
useBlockMenuStore
.getState()
.setFilters([SearchEntryFilterAnyOfItem.blocks]);
useBlockMenuStore.getState().setCreators(["alice"]);
useBlockMenuStore.getState().reset();
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.blocks,
]);
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
});
});
});
// ===========================================================================
// Section 2: controlPanelStore unit tests
// ===========================================================================
describe("controlPanelStore", () => {
it("defaults blockMenuOpen to false", () => {
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
});
it("sets blockMenuOpen", () => {
useControlPanelStore.getState().setBlockMenuOpen(true);
expect(useControlPanelStore.getState().blockMenuOpen).toBe(true);
});
it("sets forceOpenBlockMenu", () => {
useControlPanelStore.getState().setForceOpenBlockMenu(true);
expect(useControlPanelStore.getState().forceOpenBlockMenu).toBe(true);
});
it("resets all control panel state", () => {
useControlPanelStore.getState().setBlockMenuOpen(true);
useControlPanelStore.getState().setForceOpenBlockMenu(true);
useControlPanelStore.getState().setSaveControlOpen(true);
useControlPanelStore.getState().setForceOpenSave(true);
useControlPanelStore.getState().reset();
const state = useControlPanelStore.getState();
expect(state.blockMenuOpen).toBe(false);
expect(state.forceOpenBlockMenu).toBe(false);
expect(state.saveControlOpen).toBe(false);
expect(state.forceOpenSave).toBe(false);
});
});
// ===========================================================================
// Section 3: BlockMenuContent integration tests
// ===========================================================================
// We import BlockMenuContent directly to avoid dealing with the Popover wrapper.
import { BlockMenuContent } from "../components/NewControlPanel/NewBlockMenu/BlockMenuContent/BlockMenuContent";
describe("BlockMenuContent", () => {
it("shows BlockMenuDefault when there is no search query", () => {
useBlockMenuStore.getState().setSearchQuery("");
render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
expect(screen.queryByTestId("block-menu-search")).toBeNull();
});
it("shows BlockMenuSearch when a search query is present", () => {
useBlockMenuStore.getState().setSearchQuery("timer");
render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
expect(screen.queryByTestId("block-menu-default")).toBeNull();
});
it("renders the search bar", () => {
render(<BlockMenuContent />);
expect(
screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
),
).toBeDefined();
});
it("switches from default to search view when store query changes", () => {
const { rerender } = render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
// Simulate typing by setting the store directly
useBlockMenuStore.getState().setSearchQuery("webhook");
rerender(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
expect(screen.queryByTestId("block-menu-default")).toBeNull();
});
it("switches back to default view when search query is cleared", () => {
useBlockMenuStore.getState().setSearchQuery("something");
const { rerender } = render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
useBlockMenuStore.getState().setSearchQuery("");
rerender(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
expect(screen.queryByTestId("block-menu-search")).toBeNull();
});
it("typing in the search bar updates the local input value", async () => {
render(<BlockMenuContent />);
const input = screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
);
fireEvent.change(input, { target: { value: "slack" } });
expect((input as HTMLInputElement).value).toBe("slack");
});
it("shows clear button when input has text and clears on click", async () => {
render(<BlockMenuContent />);
const input = screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
);
fireEvent.change(input, { target: { value: "test" } });
// The clear button should appear
const clearButton = screen.getByRole("button");
fireEvent.click(clearButton);
await waitFor(() => {
expect((input as HTMLInputElement).value).toBe("");
});
});
});

View File

@@ -1,270 +0,0 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { UseFormReturn, useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import * as z from "zod";
import { renderHook } from "@testing-library/react";
import { useControlPanelStore } from "../stores/controlPanelStore";
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
import { NewSaveControl } from "../components/NewControlPanel/NewSaveControl/NewSaveControl";
import { useNewSaveControl } from "../components/NewControlPanel/NewSaveControl/useNewSaveControl";
const formSchema = z.object({
name: z.string().min(1, "Name is required").max(100),
description: z.string().max(500),
});
type SaveableGraphFormValues = z.infer<typeof formSchema>;
const mockHandleSave = vi.fn();
vi.mock(
"../components/NewControlPanel/NewSaveControl/useNewSaveControl",
() => ({
useNewSaveControl: vi.fn(),
}),
);
const mockUseNewSaveControl = vi.mocked(useNewSaveControl);
function createMockForm(
defaults: SaveableGraphFormValues = { name: "", description: "" },
): UseFormReturn<SaveableGraphFormValues> {
const { result } = renderHook(() =>
useForm<SaveableGraphFormValues>({
resolver: zodResolver(formSchema),
defaultValues: defaults,
}),
);
return result.current;
}
function setupMock(overrides: {
isSaving?: boolean;
graphVersion?: number;
name?: string;
description?: string;
}) {
const form = createMockForm({
name: overrides.name ?? "",
description: overrides.description ?? "",
});
mockUseNewSaveControl.mockReturnValue({
form,
isSaving: overrides.isSaving ?? false,
graphVersion: overrides.graphVersion,
handleSave: mockHandleSave,
});
return form;
}
function resetStore() {
useControlPanelStore.setState({
blockMenuOpen: false,
saveControlOpen: false,
forceOpenBlockMenu: false,
forceOpenSave: false,
});
}
beforeEach(() => {
cleanup();
resetStore();
mockHandleSave.mockReset();
});
afterEach(() => {
cleanup();
});
describe("NewSaveControl", () => {
it("renders save button trigger", () => {
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-save-button")).toBeDefined();
});
it("renders name and description inputs when popover is open", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
expect(screen.getByTestId("save-control-description-input")).toBeDefined();
});
it("does not render popover content when closed", () => {
useControlPanelStore.setState({ saveControlOpen: false });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.queryByTestId("save-control-name-input")).toBeNull();
expect(screen.queryByTestId("save-control-description-input")).toBeNull();
});
it("shows version output when graphVersion is set", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ graphVersion: 3 });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const versionInput = screen.getByTestId("save-control-version-output");
expect(versionInput).toBeDefined();
expect((versionInput as HTMLInputElement).disabled).toBe(true);
});
it("hides version output when graphVersion is undefined", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ graphVersion: undefined });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.queryByTestId("save-control-version-output")).toBeNull();
});
it("enables save button when isSaving is false", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ isSaving: false });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
expect((saveButton as HTMLButtonElement).disabled).toBe(false);
});
it("disables save button when isSaving is true", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ isSaving: true });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByRole("button", { name: /save agent/i });
expect((saveButton as HTMLButtonElement).disabled).toBe(true);
});
it("calls handleSave on form submission with valid data", async () => {
useControlPanelStore.setState({ saveControlOpen: true });
const form = setupMock({ name: "My Agent", description: "A description" });
form.setValue("name", "My Agent");
form.setValue("description", "A description");
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
fireEvent.click(saveButton);
await waitFor(() => {
expect(mockHandleSave).toHaveBeenCalledWith(
{ name: "My Agent", description: "A description" },
expect.anything(),
);
});
});
it("does not call handleSave when name is empty (validation fails)", async () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ name: "", description: "" });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
fireEvent.click(saveButton);
await waitFor(() => {
expect(mockHandleSave).not.toHaveBeenCalled();
});
});
it("popover stays open when forceOpenSave is true", () => {
useControlPanelStore.setState({
saveControlOpen: false,
forceOpenSave: true,
});
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
});
it("allows typing in name and description inputs", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const nameInput = screen.getByTestId(
"save-control-name-input",
) as HTMLInputElement;
const descriptionInput = screen.getByTestId(
"save-control-description-input",
) as HTMLInputElement;
fireEvent.change(nameInput, { target: { value: "Test Agent" } });
fireEvent.change(descriptionInput, {
target: { value: "Test Description" },
});
expect(nameInput.value).toBe("Test Agent");
expect(descriptionInput.value).toBe("Test Description");
});
it("displays save button text", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByText("Save Agent")).toBeDefined();
});
});

View File

@@ -1,147 +0,0 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { screen, fireEvent, cleanup } from "@testing-library/react";
import { render } from "@/tests/integrations/test-utils";
import React from "react";
import { useGraphStore } from "../stores/graphStore";
vi.mock(
"@/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph",
() => ({
useRunGraph: vi.fn(),
}),
);
vi.mock(
"@/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog",
() => ({
RunInputDialog: ({ isOpen }: { isOpen: boolean }) =>
isOpen ? <div data-testid="run-input-dialog">Dialog</div> : null,
}),
);
// Must import after mocks
import { useRunGraph } from "../components/BuilderActions/components/RunGraph/useRunGraph";
import { RunGraph } from "../components/BuilderActions/components/RunGraph/RunGraph";
const mockUseRunGraph = vi.mocked(useRunGraph);
function createMockReturnValue(
overrides: Partial<ReturnType<typeof useRunGraph>> = {},
) {
return {
handleRunGraph: vi.fn(),
handleStopGraph: vi.fn(),
openRunInputDialog: false,
setOpenRunInputDialog: vi.fn(),
isExecutingGraph: false,
isTerminatingGraph: false,
isSaving: false,
...overrides,
};
}
// RunGraph uses Tooltip which requires TooltipProvider
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
function renderRunGraph(flowID: string | null = "test-flow-id") {
return render(
<TooltipProvider>
<RunGraph flowID={flowID} />
</TooltipProvider>,
);
}
describe("RunGraph", () => {
beforeEach(() => {
cleanup();
mockUseRunGraph.mockReturnValue(createMockReturnValue());
useGraphStore.setState({ isGraphRunning: false });
});
afterEach(() => {
cleanup();
});
it("renders an enabled button when flowID is provided", () => {
renderRunGraph("test-flow-id");
const button = screen.getByRole("button");
expect((button as HTMLButtonElement).disabled).toBe(false);
});
it("renders a disabled button when flowID is null", () => {
renderRunGraph(null);
const button = screen.getByRole("button");
expect((button as HTMLButtonElement).disabled).toBe(true);
});
it("disables the button when isExecutingGraph is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ isExecutingGraph: true }),
);
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("disables the button when isTerminatingGraph is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ isTerminatingGraph: true }),
);
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("disables the button when isSaving is true", () => {
mockUseRunGraph.mockReturnValue(createMockReturnValue({ isSaving: true }));
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("uses data-id run-graph-button when not running", () => {
renderRunGraph();
const button = screen.getByRole("button");
expect(button.getAttribute("data-id")).toBe("run-graph-button");
});
it("uses data-id stop-graph-button when running", () => {
useGraphStore.setState({ isGraphRunning: true });
renderRunGraph();
const button = screen.getByRole("button");
expect(button.getAttribute("data-id")).toBe("stop-graph-button");
});
it("calls handleRunGraph when clicked and graph is not running", () => {
const handleRunGraph = vi.fn();
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleRunGraph }));
renderRunGraph();
fireEvent.click(screen.getByRole("button"));
expect(handleRunGraph).toHaveBeenCalledOnce();
});
it("calls handleStopGraph when clicked and graph is running", () => {
const handleStopGraph = vi.fn();
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleStopGraph }));
useGraphStore.setState({ isGraphRunning: true });
renderRunGraph();
fireEvent.click(screen.getByRole("button"));
expect(handleStopGraph).toHaveBeenCalledOnce();
});
it("renders RunInputDialog hidden by default", () => {
renderRunGraph();
expect(screen.queryByTestId("run-input-dialog")).toBeNull();
});
it("renders RunInputDialog when openRunInputDialog is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ openRunInputDialog: true }),
);
renderRunGraph();
expect(screen.getByTestId("run-input-dialog")).toBeDefined();
});
});

View File

@@ -1,257 +0,0 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { BlockUIType } from "../components/types";
vi.mock("@/services/storage/local-storage", () => {
const store: Record<string, string> = {};
return {
Key: { COPIED_FLOW_DATA: "COPIED_FLOW_DATA" },
storage: {
get: (key: string) => store[key] ?? null,
set: (key: string, value: string) => {
store[key] = value;
},
clean: (key: string) => {
delete store[key];
},
},
};
});
import { useCopyPasteStore } from "../stores/copyPasteStore";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { useHistoryStore } from "../stores/historyStore";
import { storage, Key } from "@/services/storage/local-storage";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom",
position: overrides.position ?? { x: 100, y: 200 },
selected: overrides.selected,
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "test node",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: `block-${id}`,
costs: [],
categories: [],
...overrides.data,
},
} as CustomNode;
}
describe("useCopyPasteStore", () => {
beforeEach(() => {
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().clear();
storage.clean(Key.COPIED_FLOW_DATA);
});
describe("copySelectedNodes", () => {
it("copies a single selected node to localStorage", () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
useCopyPasteStore.getState().copySelectedNodes();
const stored = storage.get(Key.COPIED_FLOW_DATA);
expect(stored).not.toBeNull();
const parsed = JSON.parse(stored!);
expect(parsed.nodes).toHaveLength(1);
expect(parsed.nodes[0].id).toBe("1");
expect(parsed.edges).toHaveLength(0);
});
it("copies only edges between selected nodes", () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: true });
const nodeC = createTestNode("c", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
useEdgeStore.setState({
edges: [
{
id: "e-ab",
source: "a",
target: "b",
sourceHandle: "out",
targetHandle: "in",
},
{
id: "e-bc",
source: "b",
target: "c",
sourceHandle: "out",
targetHandle: "in",
},
{
id: "e-ac",
source: "a",
target: "c",
sourceHandle: "out",
targetHandle: "in",
},
],
});
useCopyPasteStore.getState().copySelectedNodes();
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
expect(parsed.nodes).toHaveLength(2);
expect(parsed.edges).toHaveLength(1);
expect(parsed.edges[0].id).toBe("e-ab");
});
it("stores empty data when no nodes are selected", () => {
const node = createTestNode("1", { selected: false });
useNodeStore.setState({ nodes: [node] });
useCopyPasteStore.getState().copySelectedNodes();
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
expect(parsed.nodes).toHaveLength(0);
expect(parsed.edges).toHaveLength(0);
});
});
describe("pasteNodes", () => {
it("creates new nodes with new IDs via incrementNodeCounter", () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
const pastedNode = nodes.find((n) => n.id !== "orig");
expect(pastedNode).toBeDefined();
expect(pastedNode!.id).not.toBe("orig");
});
it("offsets pasted node positions by +50 x/y", () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
const pastedNode = nodes.find((n) => n.id !== "orig");
expect(pastedNode).toBeDefined();
expect(pastedNode!.position).toEqual({ x: 150, y: 250 });
});
it("preserves internal connections with remapped IDs", () => {
const nodeA = createTestNode("a", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeB = createTestNode("b", {
selected: true,
position: { x: 200, y: 0 },
});
useNodeStore.setState({ nodes: [nodeA, nodeB], nodeCounter: 0 });
useEdgeStore.setState({
edges: [
{
id: "e-ab",
source: "a",
target: "b",
sourceHandle: "output",
targetHandle: "input",
},
],
});
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { edges } = useEdgeStore.getState();
const newEdges = edges.filter((e) => e.id !== "e-ab");
expect(newEdges).toHaveLength(1);
const newEdge = newEdges[0];
expect(newEdge.source).not.toBe("a");
expect(newEdge.target).not.toBe("b");
const { nodes } = useNodeStore.getState();
const pastedNodeIDs = nodes
.filter((n) => n.id !== "a" && n.id !== "b")
.map((n) => n.id);
expect(pastedNodeIDs).toContain(newEdge.source);
expect(pastedNodeIDs).toContain(newEdge.target);
});
it("deselects existing nodes and selects pasted ones", () => {
const existingNode = createTestNode("existing", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeToCopy = createTestNode("copy-me", {
selected: true,
position: { x: 100, y: 100 },
});
useNodeStore.setState({
nodes: [existingNode, nodeToCopy],
nodeCounter: 0,
});
useCopyPasteStore.getState().copySelectedNodes();
// Deselect nodeToCopy, keep existingNode selected to verify deselection on paste
useNodeStore.setState({
nodes: [
{ ...existingNode, selected: true },
{ ...nodeToCopy, selected: false },
],
});
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
const originalNodes = nodes.filter(
(n) => n.id === "existing" || n.id === "copy-me",
);
const pastedNodes = nodes.filter(
(n) => n.id !== "existing" && n.id !== "copy-me",
);
originalNodes.forEach((n) => {
expect(n.selected).toBe(false);
});
pastedNodes.forEach((n) => {
expect(n.selected).toBe(true);
});
});
it("does nothing when clipboard is empty", () => {
const node = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [node], nodeCounter: 0 });
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
});

View File

@@ -1,751 +0,0 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { MarkerType } from "@xyflow/react";
import { useEdgeStore } from "../stores/edgeStore";
import { useNodeStore } from "../stores/nodeStore";
import { useHistoryStore } from "../stores/historyStore";
import type { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import type { Link } from "@/app/api/__generated__/models/link";
function makeEdge(overrides: Partial<CustomEdge> & { id: string }): CustomEdge {
return {
type: "custom",
source: "node-a",
target: "node-b",
sourceHandle: "output",
targetHandle: "input",
...overrides,
};
}
function makeExecutionResult(
overrides: Partial<NodeExecutionResult>,
): NodeExecutionResult {
return {
user_id: "user-1",
graph_id: "graph-1",
graph_version: 1,
graph_exec_id: "gexec-1",
node_exec_id: "nexec-1",
node_id: "node-1",
block_id: "block-1",
status: "INCOMPLETE",
input_data: {},
output_data: {},
add_time: new Date(),
queue_time: null,
start_time: null,
end_time: null,
...overrides,
};
}
beforeEach(() => {
useEdgeStore.setState({ edges: [] });
useNodeStore.setState({ nodes: [] });
useHistoryStore.setState({ past: [], future: [] });
});
describe("edgeStore", () => {
describe("setEdges", () => {
it("replaces all edges", () => {
const edges = [
makeEdge({ id: "e1" }),
makeEdge({ id: "e2", source: "node-c" }),
];
useEdgeStore.getState().setEdges(edges);
expect(useEdgeStore.getState().edges).toHaveLength(2);
expect(useEdgeStore.getState().edges[0].id).toBe("e1");
expect(useEdgeStore.getState().edges[1].id).toBe("e2");
});
});
describe("addEdge", () => {
it("adds an edge and auto-generates an ID", () => {
const result = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.id).toBe("n1:out->n2:in");
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].id).toBe("n1:out->n2:in");
});
it("uses provided ID when given", () => {
const result = useEdgeStore.getState().addEdge({
id: "custom-id",
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.id).toBe("custom-id");
});
it("sets type to custom and adds arrow marker", () => {
const result = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.type).toBe("custom");
expect(result.markerEnd).toEqual({
type: MarkerType.ArrowClosed,
strokeWidth: 2,
color: "#555",
});
});
it("rejects duplicate edges without adding", () => {
useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
const duplicate = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(duplicate.id).toBe("n1:out->n2:in");
expect(pushSpy).not.toHaveBeenCalled();
pushSpy.mockRestore();
});
it("pushes previous state to history store", () => {
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(pushSpy).toHaveBeenCalledWith({
nodes: [],
edges: [],
});
pushSpy.mockRestore();
});
});
describe("removeEdge", () => {
it("removes an edge by ID", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1" }), makeEdge({ id: "e2" })],
});
useEdgeStore.getState().removeEdge("e1");
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].id).toBe("e2");
});
it("does nothing when removing a non-existent edge", () => {
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
useEdgeStore.getState().removeEdge("nonexistent");
expect(useEdgeStore.getState().edges).toHaveLength(1);
});
it("pushes previous state to history store", () => {
const existingEdges = [makeEdge({ id: "e1" })];
useEdgeStore.setState({ edges: existingEdges });
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
useEdgeStore.getState().removeEdge("e1");
expect(pushSpy).toHaveBeenCalledWith({
nodes: [],
edges: existingEdges,
});
pushSpy.mockRestore();
});
});
describe("upsertMany", () => {
it("inserts new edges", () => {
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
useEdgeStore.getState().upsertMany([makeEdge({ id: "e2" })]);
expect(useEdgeStore.getState().edges).toHaveLength(2);
});
it("updates existing edges by ID", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "old-source" })],
});
useEdgeStore
.getState()
.upsertMany([makeEdge({ id: "e1", source: "new-source" })]);
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].source).toBe("new-source");
});
it("handles mixed inserts and updates", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "old" })],
});
useEdgeStore
.getState()
.upsertMany([
makeEdge({ id: "e1", source: "updated" }),
makeEdge({ id: "e2", source: "new" }),
]);
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(2);
expect(edges.find((e) => e.id === "e1")?.source).toBe("updated");
expect(edges.find((e) => e.id === "e2")?.source).toBe("new");
});
});
describe("removeEdgesByHandlePrefix", () => {
it("removes edges targeting a node with matching handle prefix", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_foo" }),
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_bar" }),
makeEdge({
id: "e3",
target: "node-b",
targetHandle: "other_handle",
}),
makeEdge({ id: "e4", target: "node-c", targetHandle: "input_foo" }),
],
});
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(2);
expect(edges.map((e) => e.id).sort()).toEqual(["e3", "e4"]);
});
it("does not remove edges where target does not match nodeId", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-b",
target: "node-c",
targetHandle: "input_x",
}),
],
});
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
expect(useEdgeStore.getState().edges).toHaveLength(1);
});
});
describe("getNodeEdges", () => {
it("returns edges where node is source", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-a");
expect(result).toHaveLength(1);
expect(result[0].id).toBe("e1");
});
it("returns edges where node is target", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-b");
expect(result).toHaveLength(1);
expect(result[0].id).toBe("e1");
});
it("returns edges for both source and target", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-b", target: "node-c" }),
makeEdge({ id: "e3", source: "node-d", target: "node-e" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-b");
expect(result).toHaveLength(2);
expect(result.map((e) => e.id).sort()).toEqual(["e1", "e2"]);
});
it("returns empty array for unconnected node", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "node-a", target: "node-b" })],
});
expect(useEdgeStore.getState().getNodeEdges("node-z")).toHaveLength(0);
});
});
describe("isInputConnected", () => {
it("returns true when target handle is connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "input")).toBe(
true,
);
});
it("returns false when target handle is not connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "other")).toBe(
false,
);
});
it("returns false when node is source not target", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-b",
target: "node-c",
sourceHandle: "output",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "output")).toBe(
false,
);
});
});
describe("isOutputConnected", () => {
it("returns true when source handle is connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-a",
sourceHandle: "output",
}),
],
});
expect(
useEdgeStore.getState().isOutputConnected("node-a", "output"),
).toBe(true);
});
it("returns false when source handle is not connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-a",
sourceHandle: "output",
}),
],
});
expect(useEdgeStore.getState().isOutputConnected("node-a", "other")).toBe(
false,
);
});
});
describe("getBackendLinks", () => {
it("converts edges to Link format", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
data: { isStatic: true },
}),
],
});
const links = useEdgeStore.getState().getBackendLinks();
expect(links).toHaveLength(1);
expect(links[0]).toEqual({
id: "e1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
is_static: true,
});
});
});
describe("addLinks", () => {
it("converts Links to edges and adds them", () => {
const links: Link[] = [
{
id: "link-1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
is_static: false,
},
];
useEdgeStore.getState().addLinks(links);
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(1);
expect(edges[0].source).toBe("n1");
expect(edges[0].target).toBe("n2");
expect(edges[0].sourceHandle).toBe("out");
expect(edges[0].targetHandle).toBe("in");
expect(edges[0].data?.isStatic).toBe(false);
});
it("adds multiple links", () => {
const links: Link[] = [
{
id: "link-1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
},
{
id: "link-2",
source_id: "n3",
sink_id: "n4",
source_name: "result",
sink_name: "value",
},
];
useEdgeStore.getState().addLinks(links);
expect(useEdgeStore.getState().edges).toHaveLength(2);
});
});
describe("getAllHandleIdsOfANode", () => {
it("returns targetHandle values for edges targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_a" }),
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_b" }),
makeEdge({ id: "e3", target: "node-c", targetHandle: "input_c" }),
],
});
const handles = useEdgeStore.getState().getAllHandleIdsOfANode("node-b");
expect(handles).toEqual(["input_a", "input_b"]);
});
it("returns empty array when no edges target the node", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "node-b", target: "node-c" })],
});
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual(
[],
);
});
it("returns empty string for edges with no targetHandle", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: undefined,
}),
],
});
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual([
"",
]);
});
});
describe("updateEdgeBeads", () => {
it("updates bead counts for edges targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "some-value" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(1);
expect(edge.data?.beadDown).toBe(1);
});
it("counts INCOMPLETE status in beadUp but not beadDown", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "INCOMPLETE",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(1);
expect(edge.data?.beadDown).toBe(0);
});
it("does not modify edges not targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-c",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
});
it("does not update edge when input_data has no matching handle", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { other_handle: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
});
it("accumulates beads across multiple executions", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data1" },
}),
);
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-2",
status: "INCOMPLETE",
input_data: { input: "data2" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(2);
expect(edge.data?.beadDown).toBe(1);
});
it("handles static edges by setting beadUp to beadDown + 1", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: {
isStatic: true,
beadUp: 0,
beadDown: 0,
beadData: new Map(),
},
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(2);
expect(edge.data?.beadDown).toBe(1);
});
});
describe("resetEdgeBeads", () => {
it("resets all bead data on all edges", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
data: {
beadUp: 5,
beadDown: 3,
beadData: new Map([["exec-1", "COMPLETED"]]),
},
}),
makeEdge({
id: "e2",
data: {
beadUp: 2,
beadDown: 1,
beadData: new Map([["exec-2", "INCOMPLETE"]]),
},
}),
],
});
useEdgeStore.getState().resetEdgeBeads();
const edges = useEdgeStore.getState().edges;
for (const edge of edges) {
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
expect(edge.data?.beadData?.size).toBe(0);
}
});
it("preserves other edge data when resetting beads", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
data: {
isStatic: true,
edgeColorClass: "text-red-500",
beadUp: 3,
beadDown: 2,
beadData: new Map(),
},
}),
],
});
useEdgeStore.getState().resetEdgeBeads();
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.isStatic).toBe(true);
expect(edge.data?.edgeColorClass).toBe("text-red-500");
expect(edge.data?.beadUp).toBe(0);
});
});
});

View File

@@ -1,347 +0,0 @@
import { describe, it, expect, beforeEach } from "vitest";
import { useGraphStore } from "../stores/graphStore";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
function createTestGraphMeta(
overrides: Partial<GraphMeta> & { id: string; name: string },
): GraphMeta {
return {
version: 1,
description: "",
is_active: true,
user_id: "test-user",
created_at: new Date("2024-01-01T00:00:00Z"),
...overrides,
};
}
function resetStore() {
useGraphStore.setState({
graphExecutionStatus: undefined,
isGraphRunning: false,
inputSchema: null,
credentialsInputSchema: null,
outputSchema: null,
availableSubGraphs: [],
});
}
beforeEach(() => {
resetStore();
});
describe("graphStore", () => {
describe("execution status transitions", () => {
it("handles QUEUED -> RUNNING -> COMPLETED transition", () => {
const { setGraphExecutionStatus } = useGraphStore.getState();
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.QUEUED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.RUNNING,
);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.COMPLETED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("handles QUEUED -> RUNNING -> FAILED transition", () => {
const { setGraphExecutionStatus } = useGraphStore.getState();
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.FAILED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.FAILED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("setGraphExecutionStatus auto-sets isGraphRunning", () => {
it("sets isGraphRunning to true for RUNNING", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
});
it("sets isGraphRunning to true for QUEUED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
});
it("sets isGraphRunning to false for COMPLETED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for FAILED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.FAILED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for TERMINATED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.TERMINATED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for INCOMPLETE", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.INCOMPLETE);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for undefined", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore.getState().setGraphExecutionStatus(undefined);
expect(useGraphStore.getState().graphExecutionStatus).toBeUndefined();
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("setIsGraphRunning", () => {
it("sets isGraphRunning independently of status", () => {
useGraphStore.getState().setIsGraphRunning(true);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore.getState().setIsGraphRunning(false);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("schema management", () => {
it("sets all three schemas via setGraphSchemas", () => {
const input = { properties: { prompt: { type: "string" } } };
const credentials = { properties: { apiKey: { type: "string" } } };
const output = { properties: { result: { type: "string" } } };
useGraphStore.getState().setGraphSchemas(input, credentials, output);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(input);
expect(state.credentialsInputSchema).toEqual(credentials);
expect(state.outputSchema).toEqual(output);
});
it("sets schemas to null", () => {
const input = { properties: { prompt: { type: "string" } } };
useGraphStore.getState().setGraphSchemas(input, null, null);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(input);
expect(state.credentialsInputSchema).toBeNull();
expect(state.outputSchema).toBeNull();
});
it("overwrites previous schemas", () => {
const first = { properties: { a: { type: "string" } } };
const second = { properties: { b: { type: "number" } } };
useGraphStore.getState().setGraphSchemas(first, first, first);
useGraphStore.getState().setGraphSchemas(second, null, second);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(second);
expect(state.credentialsInputSchema).toBeNull();
expect(state.outputSchema).toEqual(second);
});
});
describe("hasInputs", () => {
it("returns false when inputSchema is null", () => {
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns false when inputSchema has no properties", () => {
useGraphStore.getState().setGraphSchemas({}, null, null);
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns false when inputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas({ properties: {} }, null, null);
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns true when inputSchema has properties", () => {
useGraphStore
.getState()
.setGraphSchemas(
{ properties: { prompt: { type: "string" } } },
null,
null,
);
expect(useGraphStore.getState().hasInputs()).toBe(true);
});
});
describe("hasCredentials", () => {
it("returns false when credentialsInputSchema is null", () => {
expect(useGraphStore.getState().hasCredentials()).toBe(false);
});
it("returns false when credentialsInputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas(null, { properties: {} }, null);
expect(useGraphStore.getState().hasCredentials()).toBe(false);
});
it("returns true when credentialsInputSchema has properties", () => {
useGraphStore
.getState()
.setGraphSchemas(
null,
{ properties: { apiKey: { type: "string" } } },
null,
);
expect(useGraphStore.getState().hasCredentials()).toBe(true);
});
});
describe("hasOutputs", () => {
it("returns false when outputSchema is null", () => {
expect(useGraphStore.getState().hasOutputs()).toBe(false);
});
it("returns false when outputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas(null, null, { properties: {} });
expect(useGraphStore.getState().hasOutputs()).toBe(false);
});
it("returns true when outputSchema has properties", () => {
useGraphStore.getState().setGraphSchemas(null, null, {
properties: { result: { type: "string" } },
});
expect(useGraphStore.getState().hasOutputs()).toBe(true);
});
});
describe("reset", () => {
it("clears execution status and schemas but preserves outputSchema and availableSubGraphs", () => {
const subGraphs: GraphMeta[] = [
createTestGraphMeta({
id: "sub-1",
name: "Sub Graph",
description: "A sub graph",
}),
];
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphSchemas(
{ properties: { a: {} } },
{ properties: { b: {} } },
{ properties: { c: {} } },
);
useGraphStore.getState().setAvailableSubGraphs(subGraphs);
useGraphStore.getState().reset();
const state = useGraphStore.getState();
expect(state.graphExecutionStatus).toBeUndefined();
expect(state.isGraphRunning).toBe(false);
expect(state.inputSchema).toBeNull();
expect(state.credentialsInputSchema).toBeNull();
// reset does not clear outputSchema or availableSubGraphs
expect(state.outputSchema).toEqual({ properties: { c: {} } });
expect(state.availableSubGraphs).toEqual(subGraphs);
});
it("is idempotent on fresh state", () => {
useGraphStore.getState().reset();
const state = useGraphStore.getState();
expect(state.graphExecutionStatus).toBeUndefined();
expect(state.isGraphRunning).toBe(false);
expect(state.inputSchema).toBeNull();
expect(state.credentialsInputSchema).toBeNull();
});
});
describe("setAvailableSubGraphs", () => {
it("sets sub-graphs list", () => {
const graphs: GraphMeta[] = [
createTestGraphMeta({
id: "graph-1",
name: "Graph One",
description: "First graph",
}),
createTestGraphMeta({
id: "graph-2",
version: 2,
name: "Graph Two",
description: "Second graph",
}),
];
useGraphStore.getState().setAvailableSubGraphs(graphs);
expect(useGraphStore.getState().availableSubGraphs).toEqual(graphs);
});
it("replaces previous sub-graphs", () => {
const first: GraphMeta[] = [createTestGraphMeta({ id: "a", name: "A" })];
const second: GraphMeta[] = [
createTestGraphMeta({ id: "b", name: "B" }),
createTestGraphMeta({ id: "c", name: "C" }),
];
useGraphStore.getState().setAvailableSubGraphs(first);
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(1);
useGraphStore.getState().setAvailableSubGraphs(second);
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(2);
expect(useGraphStore.getState().availableSubGraphs).toEqual(second);
});
it("can set empty sub-graphs list", () => {
useGraphStore
.getState()
.setAvailableSubGraphs([createTestGraphMeta({ id: "x", name: "X" })]);
useGraphStore.getState().setAvailableSubGraphs([]);
expect(useGraphStore.getState().availableSubGraphs).toEqual([]);
});
});
});

View File

@@ -1,407 +0,0 @@
import { describe, it, expect, beforeEach } from "vitest";
import { useHistoryStore } from "../stores/historyStore";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom" as const,
position: { x: 0, y: 0 },
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "",
inputSchema: {},
outputSchema: {},
uiType: "STANDARD" as never,
block_id: `block-${id}`,
costs: [],
categories: [],
},
...overrides,
} as CustomNode;
}
function createTestEdge(
id: string,
source: string,
target: string,
): CustomEdge {
return {
id,
source,
target,
type: "custom" as const,
} as CustomEdge;
}
async function flushMicrotasks() {
await new Promise<void>((resolve) => queueMicrotask(resolve));
}
beforeEach(() => {
useHistoryStore.getState().clear();
useNodeStore.setState({ nodes: [] });
useEdgeStore.setState({ edges: [] });
});
describe("historyStore", () => {
describe("undo/redo single action", () => {
it("undoes a single pushed state", async () => {
const node = createTestNode("1");
// Initialize history with node present as baseline
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().initializeHistory();
// Simulate a change: clear nodes
useNodeStore.setState({ nodes: [] });
// Undo should restore to [node]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node]);
expect(useHistoryStore.getState().future).toHaveLength(1);
expect(useHistoryStore.getState().future[0].nodes).toEqual([]);
});
it("redoes after undo", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().initializeHistory();
// Change: clear nodes
useNodeStore.setState({ nodes: [] });
// Undo → back to [node]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node]);
// Redo → back to []
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([]);
});
});
describe("undo/redo multiple actions", () => {
it("undoes through multiple states in order", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
// Initialize with [node1] as baseline
useNodeStore.setState({ nodes: [node1] });
useHistoryStore.getState().initializeHistory();
// Second change: add node2, push pre-change state
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
// Third change: add node3, push pre-change state
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore
.getState()
.pushState({ nodes: [node1, node2], edges: [] });
await flushMicrotasks();
// Undo 1: back to [node1, node2]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
// Undo 2: back to [node1]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
});
});
describe("undo past empty history", () => {
it("does nothing when there is no history to undo", () => {
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([]);
expect(useEdgeStore.getState().edges).toEqual([]);
expect(useHistoryStore.getState().past).toHaveLength(1);
});
it("does nothing when current state equals last past entry", () => {
expect(useHistoryStore.getState().canUndo()).toBe(false);
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().past).toHaveLength(1);
expect(useHistoryStore.getState().future).toHaveLength(0);
});
});
describe("state consistency: undo after node add restores previous, redo restores added", () => {
it("undo removes added node, redo restores it", async () => {
const node = createTestNode("added");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([]);
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node]);
});
});
describe("history limits", () => {
it("does not grow past MAX_HISTORY (50)", async () => {
for (let i = 0; i < 60; i++) {
const node = createTestNode(`node-${i}`);
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({
nodes: [createTestNode(`node-${i - 1}`)],
edges: [],
});
await flushMicrotasks();
}
expect(useHistoryStore.getState().past.length).toBeLessThanOrEqual(50);
});
});
describe("edge cases", () => {
it("redo does nothing when future is empty", () => {
const nodesBefore = useNodeStore.getState().nodes;
const edgesBefore = useEdgeStore.getState().edges;
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual(nodesBefore);
expect(useEdgeStore.getState().edges).toEqual(edgesBefore);
});
it("interleaved undo/redo sequence", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
useNodeStore.setState({ nodes: [node1] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore.getState().pushState({
nodes: [node1, node2],
edges: [],
});
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
useHistoryStore.getState().redo();
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2, node3]);
});
});
describe("canUndo / canRedo", () => {
it("canUndo is false on fresh store", () => {
expect(useHistoryStore.getState().canUndo()).toBe(false);
});
it("canUndo is true when current state differs from last past entry", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().canUndo()).toBe(true);
});
it("canRedo is false on fresh store", () => {
expect(useHistoryStore.getState().canRedo()).toBe(false);
});
it("canRedo is true after undo", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().canRedo()).toBe(true);
});
it("canRedo becomes false after redo exhausts future", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
useHistoryStore.getState().redo();
expect(useHistoryStore.getState().canRedo()).toBe(false);
});
});
describe("pushState deduplication", () => {
it("does not push a state identical to the last past entry", async () => {
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().past).toHaveLength(1);
});
it("does not push if state matches current node/edge store state", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().pushState({ nodes: [node], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().past).toHaveLength(1);
});
});
describe("initializeHistory", () => {
it("resets history with current node/edge store state", async () => {
const node = createTestNode("1");
const edge = createTestEdge("e1", "1", "2");
useNodeStore.setState({ nodes: [node] });
useEdgeStore.setState({ edges: [edge] });
useNodeStore.setState({ nodes: [node, createTestNode("2")] });
useHistoryStore.getState().pushState({ nodes: [node], edges: [edge] });
await flushMicrotasks();
useHistoryStore.getState().initializeHistory();
const { past, future } = useHistoryStore.getState();
expect(past).toHaveLength(1);
expect(past[0].nodes).toEqual(useNodeStore.getState().nodes);
expect(past[0].edges).toEqual(useEdgeStore.getState().edges);
expect(future).toHaveLength(0);
});
});
describe("clear", () => {
it("resets to empty initial state", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().clear();
const { past, future } = useHistoryStore.getState();
expect(past).toEqual([{ nodes: [], edges: [] }]);
expect(future).toEqual([]);
});
});
describe("microtask batching", () => {
it("only commits the first state when multiple pushState calls happen in the same tick", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
useHistoryStore
.getState()
.pushState({ nodes: [node1, node2], edges: [] });
await flushMicrotasks();
const { past } = useHistoryStore.getState();
expect(past).toHaveLength(2);
expect(past[1].nodes).toEqual([node1]);
});
it("commits separately when pushState calls are in different ticks", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
await flushMicrotasks();
const { past } = useHistoryStore.getState();
expect(past).toHaveLength(3);
expect(past[1].nodes).toEqual([node1]);
expect(past[2].nodes).toEqual([node2]);
});
});
describe("edges in undo/redo", () => {
it("restores edges on undo and redo", async () => {
const edge = createTestEdge("e1", "1", "2");
useEdgeStore.setState({ edges: [edge] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useEdgeStore.getState().edges).toEqual([]);
useHistoryStore.getState().redo();
expect(useEdgeStore.getState().edges).toEqual([edge]);
});
});
describe("pushState clears future", () => {
it("clears future when a new state is pushed after undo", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
// Initialize empty
useHistoryStore.getState().initializeHistory();
// First change: set [node1]
useNodeStore.setState({ nodes: [node1] });
// Second change: set [node1, node2], push pre-change [node1]
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
// Undo: back to [node1]
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().future).toHaveLength(1);
// New diverging change: add node3 instead of node2
useNodeStore.setState({ nodes: [node1, node3] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().future).toHaveLength(0);
});
});
});

View File

@@ -1,791 +0,0 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { useNodeStore } from "../stores/nodeStore";
import { useHistoryStore } from "../stores/historyStore";
import { useEdgeStore } from "../stores/edgeStore";
import { BlockUIType } from "../components/types";
import type { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { CustomNodeData } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
function createTestNode(overrides: {
id: string;
position?: { x: number; y: number };
data?: Partial<CustomNodeData>;
}): CustomNode {
const defaults: CustomNodeData = {
hardcodedValues: {},
title: "Test Block",
description: "A test block",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: "test-block-id",
costs: [],
categories: [],
};
return {
id: overrides.id,
type: "custom",
position: overrides.position ?? { x: 0, y: 0 },
data: { ...defaults, ...overrides.data },
};
}
function createExecutionResult(
overrides: Partial<NodeExecutionResult> = {},
): NodeExecutionResult {
return {
node_exec_id: overrides.node_exec_id ?? "exec-1",
node_id: overrides.node_id ?? "1",
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
graph_id: overrides.graph_id ?? "graph-1",
graph_version: overrides.graph_version ?? 1,
user_id: overrides.user_id ?? "test-user",
block_id: overrides.block_id ?? "block-1",
status: overrides.status ?? "COMPLETED",
input_data: overrides.input_data ?? { input_key: "input_value" },
output_data: overrides.output_data ?? { output_key: ["output_value"] },
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
};
}
function resetStores() {
useNodeStore.setState({
nodes: [],
nodeCounter: 0,
nodeAdvancedStates: {},
latestNodeInputData: {},
latestNodeOutputData: {},
accumulatedNodeInputData: {},
accumulatedNodeOutputData: {},
nodesInResolutionMode: new Set(),
brokenEdgeIDs: new Map(),
nodeResolutionData: new Map(),
});
useEdgeStore.setState({ edges: [] });
useHistoryStore.setState({ past: [], future: [] });
}
describe("nodeStore", () => {
beforeEach(() => {
resetStores();
vi.restoreAllMocks();
});
describe("node lifecycle", () => {
it("starts with empty nodes", () => {
const { nodes } = useNodeStore.getState();
expect(nodes).toEqual([]);
});
it("adds a single node with addNode", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().addNode(node);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
it("sets nodes with setNodes, replacing existing ones", () => {
const node1 = createTestNode({ id: "1" });
const node2 = createTestNode({ id: "2" });
useNodeStore.getState().addNode(node1);
useNodeStore.getState().setNodes([node2]);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("2");
});
it("removes nodes via onNodesChange", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().setNodes([node]);
useNodeStore.getState().onNodesChange([{ type: "remove", id: "1" }]);
expect(useNodeStore.getState().nodes).toHaveLength(0);
});
it("updates node data with updateNodeData", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().addNode(node);
useNodeStore.getState().updateNodeData("1", { title: "Updated Title" });
const updated = useNodeStore.getState().nodes[0];
expect(updated.data.title).toBe("Updated Title");
expect(updated.data.block_id).toBe("test-block-id");
});
it("updateNodeData does not affect other nodes", () => {
const node1 = createTestNode({ id: "1" });
const node2 = createTestNode({
id: "2",
data: { title: "Node 2" },
});
useNodeStore.getState().setNodes([node1, node2]);
useNodeStore.getState().updateNodeData("1", { title: "Changed" });
expect(useNodeStore.getState().nodes[1].data.title).toBe("Node 2");
});
});
describe("bulk operations", () => {
it("adds multiple nodes with addNodes", () => {
const nodes = [
createTestNode({ id: "1" }),
createTestNode({ id: "2" }),
createTestNode({ id: "3" }),
];
useNodeStore.getState().addNodes(nodes);
expect(useNodeStore.getState().nodes).toHaveLength(3);
});
it("removes multiple nodes via onNodesChange", () => {
const nodes = [
createTestNode({ id: "1" }),
createTestNode({ id: "2" }),
createTestNode({ id: "3" }),
];
useNodeStore.getState().setNodes(nodes);
useNodeStore.getState().onNodesChange([
{ type: "remove", id: "1" },
{ type: "remove", id: "3" },
]);
const remaining = useNodeStore.getState().nodes;
expect(remaining).toHaveLength(1);
expect(remaining[0].id).toBe("2");
});
});
describe("nodeCounter", () => {
it("starts at zero", () => {
expect(useNodeStore.getState().nodeCounter).toBe(0);
});
it("increments the counter", () => {
useNodeStore.getState().incrementNodeCounter();
expect(useNodeStore.getState().nodeCounter).toBe(1);
useNodeStore.getState().incrementNodeCounter();
expect(useNodeStore.getState().nodeCounter).toBe(2);
});
it("sets the counter to a specific value", () => {
useNodeStore.getState().setNodeCounter(42);
expect(useNodeStore.getState().nodeCounter).toBe(42);
});
});
describe("advanced states", () => {
it("defaults to false for unknown node IDs", () => {
expect(useNodeStore.getState().getShowAdvanced("unknown")).toBe(false);
});
it("toggles advanced state", () => {
useNodeStore.getState().toggleAdvanced("node-1");
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
useNodeStore.getState().toggleAdvanced("node-1");
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
});
it("sets advanced state explicitly", () => {
useNodeStore.getState().setShowAdvanced("node-1", true);
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
useNodeStore.getState().setShowAdvanced("node-1", false);
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
});
});
describe("convertCustomNodeToBackendNode", () => {
it("converts a node with minimal data", () => {
const node = createTestNode({
id: "42",
position: { x: 100, y: 200 },
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.id).toBe("42");
expect(backend.block_id).toBe("test-block-id");
expect(backend.input_default).toEqual({});
expect(backend.metadata).toEqual({ position: { x: 100, y: 200 } });
});
it("includes customized_name when present in metadata", () => {
const node = createTestNode({
id: "1",
data: {
metadata: { customized_name: "My Custom Name" },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.metadata).toHaveProperty(
"customized_name",
"My Custom Name",
);
});
it("includes credentials_optional when present in metadata", () => {
const node = createTestNode({
id: "1",
data: {
metadata: { credentials_optional: true },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.metadata).toHaveProperty("credentials_optional", true);
});
it("prunes empty values from hardcodedValues", () => {
const node = createTestNode({
id: "1",
data: {
hardcodedValues: { filled: "value", empty: "" },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.input_default).toEqual({ filled: "value" });
expect(backend.input_default).not.toHaveProperty("empty");
});
});
describe("getBackendNodes", () => {
it("converts all nodes to backend format", () => {
useNodeStore
.getState()
.setNodes([
createTestNode({ id: "1", position: { x: 0, y: 0 } }),
createTestNode({ id: "2", position: { x: 100, y: 100 } }),
]);
const backendNodes = useNodeStore.getState().getBackendNodes();
expect(backendNodes).toHaveLength(2);
expect(backendNodes[0].id).toBe("1");
expect(backendNodes[1].id).toBe("2");
});
});
describe("node status", () => {
it("returns undefined for a node with no status", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
});
it("updates node status", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
expect(useNodeStore.getState().getNodeStatus("1")).toBe("RUNNING");
useNodeStore.getState().updateNodeStatus("1", "COMPLETED");
expect(useNodeStore.getState().getNodeStatus("1")).toBe("COMPLETED");
});
it("cleans all node statuses", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
useNodeStore.getState().updateNodeStatus("2", "COMPLETED");
useNodeStore.getState().cleanNodesStatuses();
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeStatus("2")).toBeUndefined();
});
it("updating status for non-existent node does not crash", () => {
useNodeStore.getState().updateNodeStatus("nonexistent", "RUNNING");
expect(
useNodeStore.getState().getNodeStatus("nonexistent"),
).toBeUndefined();
});
});
describe("execution result tracking", () => {
it("returns empty array for node with no results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
});
it("tracks a single execution result", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
const result = createExecutionResult({ node_id: "1" });
useNodeStore.getState().updateNodeExecutionResult("1", result);
const results = useNodeStore.getState().getNodeExecutionResults("1");
expect(results).toHaveLength(1);
expect(results[0].node_exec_id).toBe("exec-1");
});
it("accumulates multiple execution results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "val1" },
output_data: { key: ["out1"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "val2" },
output_data: { key: ["out2"] },
}),
);
expect(useNodeStore.getState().getNodeExecutionResults("1")).toHaveLength(
2,
);
});
it("updates latest input/output data", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "first" },
output_data: { key: ["first_out"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "second" },
output_data: { key: ["second_out"] },
}),
);
expect(useNodeStore.getState().getLatestNodeInputData("1")).toEqual({
key: "second",
});
expect(useNodeStore.getState().getLatestNodeOutputData("1")).toEqual({
key: ["second_out"],
});
});
it("accumulates input/output data across results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "val1" },
output_data: { key: ["out1"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "val2" },
output_data: { key: ["out2"] },
}),
);
const accInput = useNodeStore.getState().getAccumulatedNodeInputData("1");
expect(accInput.key).toEqual(["val1", "val2"]);
const accOutput = useNodeStore
.getState()
.getAccumulatedNodeOutputData("1");
expect(accOutput.key).toEqual(["out1", "out2"]);
});
it("deduplicates execution results by node_exec_id", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "original" },
output_data: { key: ["original_out"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "updated" },
output_data: { key: ["updated_out"] },
}),
);
const results = useNodeStore.getState().getNodeExecutionResults("1");
expect(results).toHaveLength(1);
expect(results[0].input_data).toEqual({ key: "updated" });
});
it("returns the latest execution result", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-1" }),
);
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-2" }),
);
const latest = useNodeStore.getState().getLatestNodeExecutionResult("1");
expect(latest?.node_exec_id).toBe("exec-2");
});
it("returns undefined for latest result on unknown node", () => {
expect(
useNodeStore.getState().getLatestNodeExecutionResult("unknown"),
).toBeUndefined();
});
it("clears all execution results", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-1" }),
);
useNodeStore
.getState()
.updateNodeExecutionResult(
"2",
createExecutionResult({ node_exec_id: "exec-2" }),
);
useNodeStore.getState().clearAllNodeExecutionResults();
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
expect(useNodeStore.getState().getNodeExecutionResults("2")).toEqual([]);
expect(
useNodeStore.getState().getLatestNodeInputData("1"),
).toBeUndefined();
expect(
useNodeStore.getState().getLatestNodeOutputData("1"),
).toBeUndefined();
expect(useNodeStore.getState().getAccumulatedNodeInputData("1")).toEqual(
{},
);
expect(useNodeStore.getState().getAccumulatedNodeOutputData("1")).toEqual(
{},
);
});
it("returns empty object for accumulated data on unknown node", () => {
expect(
useNodeStore.getState().getAccumulatedNodeInputData("unknown"),
).toEqual({});
expect(
useNodeStore.getState().getAccumulatedNodeOutputData("unknown"),
).toEqual({});
});
});
describe("getNodeBlockUIType", () => {
it("returns the node UI type", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.INPUT,
},
}),
);
expect(useNodeStore.getState().getNodeBlockUIType("1")).toBe(
BlockUIType.INPUT,
);
});
it("defaults to STANDARD for unknown node IDs", () => {
expect(useNodeStore.getState().getNodeBlockUIType("unknown")).toBe(
BlockUIType.STANDARD,
);
});
});
describe("hasWebhookNodes", () => {
it("returns false when there are no webhook nodes", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().hasWebhookNodes()).toBe(false);
});
it("returns true when a WEBHOOK node exists", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.WEBHOOK,
},
}),
);
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
});
it("returns true when a WEBHOOK_MANUAL node exists", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.WEBHOOK_MANUAL,
},
}),
);
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
});
});
describe("node errors", () => {
it("returns undefined for a node with no errors", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
});
it("sets and retrieves node errors", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
const errors = { field1: "required", field2: "invalid" };
useNodeStore.getState().updateNodeErrors("1", errors);
expect(useNodeStore.getState().getNodeErrors("1")).toEqual(errors);
});
it("clears errors for a specific node", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeErrors("1", { f: "err" });
useNodeStore.getState().updateNodeErrors("2", { g: "err2" });
useNodeStore.getState().clearNodeErrors("1");
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeErrors("2")).toEqual({ g: "err2" });
});
it("clears all node errors", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeErrors("1", { a: "err1" });
useNodeStore.getState().updateNodeErrors("2", { b: "err2" });
useNodeStore.getState().clearAllNodeErrors();
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeErrors("2")).toBeUndefined();
});
it("sets errors by backend ID matching node id", () => {
useNodeStore.getState().addNode(createTestNode({ id: "backend-1" }));
useNodeStore
.getState()
.setNodeErrorsForBackendId("backend-1", { x: "error" });
expect(useNodeStore.getState().getNodeErrors("backend-1")).toEqual({
x: "error",
});
});
});
describe("getHardCodedValues", () => {
it("returns hardcoded values for a node", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
hardcodedValues: { key: "value" },
},
}),
);
expect(useNodeStore.getState().getHardCodedValues("1")).toEqual({
key: "value",
});
});
it("returns empty object for unknown node", () => {
expect(useNodeStore.getState().getHardCodedValues("unknown")).toEqual({});
});
});
describe("credentials optional", () => {
it("sets credentials_optional in node metadata", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().setCredentialsOptional("1", true);
const node = useNodeStore.getState().nodes[0];
expect(node.data.metadata?.credentials_optional).toBe(true);
});
});
describe("resolution mode", () => {
it("defaults to not in resolution mode", () => {
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
});
it("enters and exits resolution mode", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(true);
useNodeStore.getState().setNodeResolutionMode("1", false);
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
});
it("tracks broken edge IDs", () => {
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(true);
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
expect(useNodeStore.getState().isEdgeBroken("edge-3")).toBe(false);
});
it("removes individual broken edge IDs", () => {
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
useNodeStore.getState().removeBrokenEdgeID("node-1", "edge-1");
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
});
it("clears all resolution state", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
useNodeStore.getState().clearResolutionState();
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
});
it("cleans up broken edges when exiting resolution mode", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
useNodeStore.getState().setNodeResolutionMode("1", false);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
});
});
describe("edge cases", () => {
it("handles updating data on a non-existent node gracefully", () => {
useNodeStore
.getState()
.updateNodeData("nonexistent", { title: "New Title" });
expect(useNodeStore.getState().nodes).toHaveLength(0);
});
it("handles removing a non-existent node gracefully", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore
.getState()
.onNodesChange([{ type: "remove", id: "nonexistent" }]);
expect(useNodeStore.getState().nodes).toHaveLength(1);
});
it("handles duplicate node IDs in addNodes", () => {
useNodeStore.getState().addNodes([
createTestNode({
id: "1",
data: { title: "First" },
}),
createTestNode({
id: "1",
data: { title: "Second" },
}),
]);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
expect(nodes[0].data.title).toBe("First");
expect(nodes[1].data.title).toBe("Second");
});
it("updating node status mid-execution preserves other data", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
title: "My Node",
hardcodedValues: { key: "val" },
},
}),
);
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
const node = useNodeStore.getState().nodes[0];
expect(node.data.status).toBe("RUNNING");
expect(node.data.title).toBe("My Node");
expect(node.data.hardcodedValues).toEqual({ key: "val" });
});
it("execution result for non-existent node does not add it", () => {
useNodeStore
.getState()
.updateNodeExecutionResult(
"nonexistent",
createExecutionResult({ node_exec_id: "exec-1" }),
);
expect(useNodeStore.getState().nodes).toHaveLength(0);
expect(
useNodeStore.getState().getNodeExecutionResults("nonexistent"),
).toEqual([]);
});
it("getBackendNodes returns empty array when no nodes exist", () => {
expect(useNodeStore.getState().getBackendNodes()).toEqual([]);
});
});
});

View File

@@ -1,567 +0,0 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { BlockUIType } from "../components/types";
// ---- Mocks ----
const mockGetViewport = vi.fn(() => ({ x: 0, y: 0, zoom: 1 }));
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: vi.fn(() => ({
getViewport: mockGetViewport,
})),
};
});
const mockToast = vi.fn();
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: vi.fn(() => ({ toast: mockToast })),
}));
let uuidCounter = 0;
vi.mock("uuid", () => ({
v4: vi.fn(() => `new-uuid-${++uuidCounter}`),
}));
// Mock navigator.clipboard
const mockWriteText = vi.fn(() => Promise.resolve());
const mockReadText = vi.fn(() => Promise.resolve(""));
Object.defineProperty(navigator, "clipboard", {
value: {
writeText: mockWriteText,
readText: mockReadText,
},
writable: true,
configurable: true,
});
// Mock window.innerWidth / innerHeight for viewport centering calculations
Object.defineProperty(window, "innerWidth", { value: 1000, writable: true });
Object.defineProperty(window, "innerHeight", { value: 800, writable: true });
import { useCopyPaste } from "../components/FlowEditor/Flow/useCopyPaste";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { useHistoryStore } from "../stores/historyStore";
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
const CLIPBOARD_PREFIX = "autogpt-flow-data:";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom",
position: overrides.position ?? { x: 100, y: 200 },
selected: overrides.selected,
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "test node",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: `block-${id}`,
costs: [],
categories: [],
...overrides.data,
},
} as CustomNode;
}
function createTestEdge(
id: string,
source: string,
target: string,
sourceHandle = "out",
targetHandle = "in",
): CustomEdge {
return {
id,
source,
target,
sourceHandle,
targetHandle,
} as CustomEdge;
}
function makeCopyEvent(): KeyboardEvent {
return new KeyboardEvent("keydown", {
key: "c",
ctrlKey: true,
bubbles: true,
});
}
function makePasteEvent(): KeyboardEvent {
return new KeyboardEvent("keydown", {
key: "v",
ctrlKey: true,
bubbles: true,
});
}
function clipboardPayload(nodes: CustomNode[], edges: CustomEdge[]): string {
return `${CLIPBOARD_PREFIX}${JSON.stringify({ nodes, edges })}`;
}
describe("useCopyPaste", () => {
beforeEach(() => {
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().clear();
mockWriteText.mockClear();
mockReadText.mockClear();
mockToast.mockClear();
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
uuidCounter = 0;
// Ensure no input element is focused
if (document.activeElement && document.activeElement !== document.body) {
(document.activeElement as HTMLElement).blur();
}
});
describe("copy (Ctrl+C)", () => {
it("copies a single selected node to clipboard with prefix", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const written = (mockWriteText.mock.calls as string[][])[0][0];
expect(written.startsWith(CLIPBOARD_PREFIX)).toBe(true);
const parsed = JSON.parse(written.slice(CLIPBOARD_PREFIX.length));
expect(parsed.nodes).toHaveLength(1);
expect(parsed.nodes[0].id).toBe("1");
expect(parsed.edges).toHaveLength(0);
});
it("shows a success toast after copying", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Copied successfully",
}),
);
});
});
it("copies multiple connected nodes and preserves internal edges", async () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: true });
const nodeC = createTestNode("c", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
useEdgeStore.setState({
edges: [
createTestEdge("e-ab", "a", "b"),
createTestEdge("e-bc", "b", "c"),
],
});
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(2);
expect(parsed.edges).toHaveLength(1);
expect(parsed.edges[0].id).toBe("e-ab");
});
it("drops external edges where one endpoint is not selected", async () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB] });
useEdgeStore.setState({
edges: [createTestEdge("e-ab", "a", "b")],
});
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(1);
expect(parsed.edges).toHaveLength(0);
});
it("copies nothing when no nodes are selected", async () => {
const node = createTestNode("1", { selected: false });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(0);
expect(parsed.edges).toHaveLength(0);
});
});
describe("paste (Ctrl+V)", () => {
it("creates new nodes with new UUIDs", async () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
const { nodes } = useNodeStore.getState();
expect(nodes[0].id).toBe("new-uuid-1");
expect(nodes[0].id).not.toBe("orig");
});
it("centers pasted nodes in the current viewport", async () => {
// Viewport at origin, zoom 1 => center = (500, 400)
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 100 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
const { nodes } = useNodeStore.getState();
// Single node: center of bounds = (100, 100)
// Viewport center = (500, 400)
// Offset = (400, 300)
// New position = (100 + 400, 100 + 300) = (500, 400)
expect(nodes[0].position).toEqual({ x: 500, y: 400 });
});
it("deselects existing nodes and selects pasted nodes", async () => {
const existingNode = createTestNode("existing", {
selected: true,
position: { x: 0, y: 0 },
});
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const nodeToPaste = createTestNode("paste-me", {
selected: false,
position: { x: 100, y: 100 },
});
mockReadText.mockResolvedValue(clipboardPayload([nodeToPaste], []));
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
});
const { nodes } = useNodeStore.getState();
const originalNode = nodes.find((n) => n.id === "existing");
const pastedNode = nodes.find((n) => n.id !== "existing");
expect(originalNode!.selected).toBe(false);
expect(pastedNode!.selected).toBe(true);
});
it("remaps edge source/target IDs to newly created node IDs", async () => {
const nodeA = createTestNode("a", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeB = createTestNode("b", {
selected: true,
position: { x: 200, y: 0 },
});
const edge = createTestEdge("e-ab", "a", "b", "output", "input");
mockReadText.mockResolvedValue(clipboardPayload([nodeA, nodeB], [edge]));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
});
// Wait for edges to be added too
await vi.waitFor(() => {
const { edges } = useEdgeStore.getState();
expect(edges).toHaveLength(1);
});
const { edges } = useEdgeStore.getState();
const newEdge = edges[0];
// Edge source/target should be remapped to new UUIDs, not "a"/"b"
expect(newEdge.source).not.toBe("a");
expect(newEdge.target).not.toBe("b");
expect(newEdge.source).toBe("new-uuid-1");
expect(newEdge.target).toBe("new-uuid-2");
expect(newEdge.sourceHandle).toBe("output");
expect(newEdge.targetHandle).toBe("input");
});
it("does nothing when clipboard does not have the expected prefix", async () => {
mockReadText.mockResolvedValue("some random text");
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
// Give async operations time to settle
await vi.waitFor(() => {
expect(mockReadText).toHaveBeenCalled();
});
// Ensure no state changes happen after clipboard read
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
it("does nothing when clipboard is empty", async () => {
mockReadText.mockResolvedValue("");
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
expect(mockReadText).toHaveBeenCalled();
});
// Ensure no state changes happen after clipboard read
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
});
describe("input field focus guard", () => {
it("ignores Ctrl+C when an input element is focused", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const input = document.createElement("input");
document.body.appendChild(input);
input.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
// Clipboard write should NOT be called
expect(mockWriteText).not.toHaveBeenCalled();
document.body.removeChild(input);
});
it("ignores Ctrl+V when a textarea element is focused", async () => {
mockReadText.mockResolvedValue(
clipboardPayload(
[createTestNode("a", { position: { x: 0, y: 0 } })],
[],
),
);
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const textarea = document.createElement("textarea");
document.body.appendChild(textarea);
textarea.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
expect(mockReadText).not.toHaveBeenCalled();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(0);
document.body.removeChild(textarea);
});
it("ignores keypresses when a contenteditable element is focused", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const div = document.createElement("div");
div.setAttribute("contenteditable", "true");
document.body.appendChild(div);
div.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
expect(mockWriteText).not.toHaveBeenCalled();
document.body.removeChild(div);
});
});
describe("meta key support (macOS)", () => {
it("handles Cmd+C (metaKey) the same as Ctrl+C", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
const metaCopyEvent = new KeyboardEvent("keydown", {
key: "c",
metaKey: true,
bubbles: true,
});
act(() => {
result.current(metaCopyEvent);
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
});
it("handles Cmd+V (metaKey) the same as Ctrl+V", async () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 0, y: 0 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
const metaPasteEvent = new KeyboardEvent("keydown", {
key: "v",
metaKey: true,
bubbles: true,
});
act(() => {
result.current(metaPasteEvent);
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
});
});
});

View File

@@ -1,134 +0,0 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
const mockScreenToFlowPosition = vi.fn((pos: { x: number; y: number }) => pos);
const mockFitView = vi.fn();
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: () => ({
screenToFlowPosition: mockScreenToFlowPosition,
fitView: mockFitView,
}),
};
});
const mockSetQueryStates = vi.fn();
let mockQueryStateValues: {
flowID: string | null;
flowVersion: number | null;
flowExecutionID: string | null;
} = {
flowID: null,
flowVersion: null,
flowExecutionID: null,
};
vi.mock("nuqs", () => ({
parseAsString: {},
parseAsInteger: {},
useQueryStates: vi.fn(() => [mockQueryStateValues, mockSetQueryStates]),
}));
let mockGraphLoading = false;
let mockBlocksLoading = false;
vi.mock("@/app/api/__generated__/endpoints/graphs/graphs", () => ({
useGetV1GetSpecificGraph: vi.fn(() => ({
data: undefined,
isLoading: mockGraphLoading,
})),
useGetV1GetExecutionDetails: vi.fn(() => ({
data: undefined,
})),
useGetV1ListUserGraphs: vi.fn(() => ({
data: undefined,
})),
}));
vi.mock("@/app/api/__generated__/endpoints/default/default", () => ({
useGetV2GetSpecificBlocks: vi.fn(() => ({
data: undefined,
isLoading: mockBlocksLoading,
})),
}));
vi.mock("@/app/api/helpers", () => ({
okData: (res: { data: unknown }) => res?.data,
}));
vi.mock("../components/helper", () => ({
convertNodesPlusBlockInfoIntoCustomNodes: vi.fn(),
}));
describe("useFlow", () => {
beforeEach(() => {
vi.clearAllMocks();
vi.useFakeTimers({ shouldAdvanceTime: true });
mockGraphLoading = false;
mockBlocksLoading = false;
mockQueryStateValues = {
flowID: null,
flowVersion: null,
flowExecutionID: null,
};
});
afterEach(() => {
vi.useRealTimers();
});
describe("loading states", () => {
it("returns isFlowContentLoading true when graph is loading", async () => {
mockGraphLoading = true;
mockQueryStateValues = {
flowID: "test-flow",
flowVersion: 1,
flowExecutionID: null,
};
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(true);
});
it("returns isFlowContentLoading true when blocks are loading", async () => {
mockBlocksLoading = true;
mockQueryStateValues = {
flowID: "test-flow",
flowVersion: 1,
flowExecutionID: null,
};
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(true);
});
it("returns isFlowContentLoading false when neither is loading", async () => {
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(false);
});
});
describe("initial load completion", () => {
it("marks initial load complete for new flows without flowID", async () => {
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isInitialLoadComplete).toBe(false);
await act(async () => {
vi.advanceTimersByTime(300);
});
expect(result.current.isInitialLoadComplete).toBe(true);
});
});
});

View File

@@ -1,14 +1,8 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -89,10 +83,9 @@ export function CopilotPage() {
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
// Delete functionality
// Delete functionality (available via ChatSidebar context menu on all viewports)
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -148,38 +141,6 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>

View File

@@ -2,7 +2,6 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -40,7 +38,6 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -63,7 +60,6 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div

View File

@@ -1,4 +1,3 @@
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
interface Args {
@@ -17,16 +16,6 @@ export function useChatInput({
}: Args) {
const [value, setValue] = useState("");
const [isSending, setIsSending] = useState(false);
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
useEffect(
function consumeInitialPrompt() {
if (!initialPrompt) return;
setValue((prev) => (prev.length === 0 ? initialPrompt : prev));
setInitialPrompt(null);
},
[initialPrompt, setInitialPrompt],
);
useEffect(
function focusOnMount() {

View File

@@ -30,7 +30,6 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"

View File

@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
@@ -256,11 +257,10 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-5 flex items-center gap-1">
<div className="flex items-center">
<UsageLimits />
<NotificationToggle />
<div className="relative left-1">
<SidebarTrigger />
</div>
<SidebarTrigger />
</div>
</div>
{sessionId ? (

View File

@@ -7,6 +7,7 @@ import {
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
@@ -48,10 +49,7 @@ export function NotificationToggle() {
return (
<Popover>
<PopoverTrigger asChild>
<button
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
aria-label="Notification settings"
>
<Button variant="ghost" size="icon" aria-label="Notification settings">
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
@@ -59,7 +57,7 @@ export function NotificationToggle() {
) : (
<Bell className="!size-5" />
)}
</button>
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">

View File

@@ -0,0 +1,38 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { UsagePanelContent } from "./UsagePanelContent";
export { UsagePanelContent, formatResetTime } from "./UsagePanelContent";
export function UsageLimits() {
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}

View File

@@ -0,0 +1,118 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import Link from "next/link";
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<Link
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</Link>
)}
</div>
);
}

View File

@@ -0,0 +1,124 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the generated Orval hook
const mockUseGetV2GetCopilotUsage = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseGetV2GetCopilotUsage.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: undefined,
isLoading: true,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});

View File

@@ -7,10 +7,6 @@ export interface DeleteTarget {
}
interface CopilotUIState {
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
initialPrompt: string | null;
setInitialPrompt: (prompt: string | null) => void;
sessionToDelete: DeleteTarget | null;
setSessionToDelete: (target: DeleteTarget | null) => void;
@@ -35,9 +31,6 @@ interface CopilotUIState {
}
export const useCopilotUIStore = create<CopilotUIState>((set) => ({
initialPrompt: null,
setInitialPrompt: (prompt) => set({ initialPrompt: prompt }),
sessionToDelete: null,
setSessionToDelete: (target) => set({ sessionToDelete: target }),

View File

@@ -19,42 +19,6 @@ import { useCopilotStream } from "./useCopilotStream";
const TITLE_POLL_INTERVAL_MS = 2_000;
const TITLE_POLL_MAX_ATTEMPTS = 5;
/**
* Extract a prompt from the URL hash fragment.
* Supports: /copilot#prompt=URL-encoded-text
* Optionally auto-submits if ?autosubmit=true is in the query string.
* Returns null if no prompt is present.
*/
function extractPromptFromUrl(): {
prompt: string;
autosubmit: boolean;
} | null {
if (typeof window === "undefined") return null;
const hash = window.location.hash;
if (!hash) return null;
const hashParams = new URLSearchParams(hash.slice(1));
const prompt = hashParams.get("prompt");
if (!prompt || !prompt.trim()) return null;
const searchParams = new URLSearchParams(window.location.search);
const autosubmit = searchParams.get("autosubmit") === "true";
// Clean up hash + autosubmit param only (preserve other query params)
const cleanURL = new URL(window.location.href);
cleanURL.hash = "";
cleanURL.searchParams.delete("autosubmit");
window.history.replaceState(
null,
"",
`${cleanURL.pathname}${cleanURL.search}`,
);
return { prompt: prompt.trim(), autosubmit };
}
interface UploadedFile {
file_id: string;
name: string;
@@ -163,28 +127,6 @@ export function useCopilotPage() {
}
}, [sessionId, pendingMessage, sendMessage]);
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
const { setInitialPrompt } = useCopilotUIStore();
const hasProcessedUrlPrompt = useRef(false);
useEffect(() => {
if (hasProcessedUrlPrompt.current) return;
const urlPrompt = extractPromptFromUrl();
if (!urlPrompt) return;
hasProcessedUrlPrompt.current = true;
if (urlPrompt.autosubmit) {
setPendingMessage(urlPrompt.prompt);
void createSession().catch(() => {
setPendingMessage(null);
setInitialPrompt(urlPrompt.prompt);
});
} else {
setInitialPrompt(urlPrompt.prompt);
}
}, [createSession, setInitialPrompt]);
async function uploadFiles(
files: File[],
sid: string,

View File

@@ -1,4 +1,5 @@
import {
getGetV2GetCopilotUsageQueryKey,
getGetV2GetSessionQueryKey,
postV2CancelSessionTask,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -177,12 +178,41 @@ export function useCopilotStream({
onError: (error) => {
if (!sessionId) return;
// Detect rate limit (429) responses and show reset time to the user.
// The SDK throws a plain Error whose message is the raw response body
// (FastAPI returns {"detail": "...usage limit..."} for 429s).
let errorDetail: string = error.message;
try {
const parsed = JSON.parse(error.message) as unknown;
if (
typeof parsed === "object" &&
parsed !== null &&
"detail" in parsed &&
typeof (parsed as { detail: unknown }).detail === "string"
) {
errorDetail = (parsed as { detail: string }).detail;
}
} catch {
// Not JSON — use message as-is
}
const isRateLimited = errorDetail.toLowerCase().includes("usage limit");
if (isRateLimited) {
toast({
title: "Usage limit reached",
description:
errorDetail ||
"You've reached your usage limit. Please try again later.",
variant: "destructive",
});
return;
}
// Detect authentication failures (from getAuthHeaders or 401 responses)
const isAuthError =
error.message.includes("Authentication failed") ||
error.message.includes("Unauthorized") ||
error.message.includes("Not authenticated") ||
error.message.toLowerCase().includes("401");
errorDetail.includes("Authentication failed") ||
errorDetail.includes("Unauthorized") ||
errorDetail.includes("Not authenticated") ||
errorDetail.toLowerCase().includes("401");
if (isAuthError) {
toast({
title: "Authentication error",
@@ -307,6 +337,9 @@ export function useCopilotStream({
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
queryClient.invalidateQueries({
queryKey: getGetV2GetCopilotUsageQueryKey(),
});
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;

View File

@@ -11,6 +11,9 @@ import {
import { RefundModal } from "./RefundModal";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
Table,
@@ -21,6 +24,32 @@ import {
TableRow,
} from "@/components/__legacy__/ui/table";
function CoPilotUsageSection() {
const router = useRouter();
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>
<Button className="w-full" onClick={() => router.push("/copilot")}>
Open CoPilot
</Button>
</div>
);
}
export default function CreditsPage() {
const api = useBackendAPI();
const {
@@ -237,11 +266,13 @@ export default function CreditsPage() {
</Button>
)}
</form>
{/* CoPilot Usage Limits */}
<CoPilotUsageSection />
</div>
<div className="my-6 space-y-4">
{/* Payment Portal */}
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
<p className="text-neutral-600">
You can manage your cards and see your payment history in the

View File

@@ -1267,7 +1267,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1382,6 +1382,28 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/usage": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8455,6 +8477,16 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"CoPilotUsageStatus": {
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
},
"type": "object",
"required": ["daily", "weekly"],
"title": "CoPilotUsageStatus",
"description": "Current usage status for a user across all windows."
},
"ContentType": {
"type": "string",
"enum": [
@@ -12190,6 +12222,16 @@
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",
"default": 0
},
"total_completion_tokens": {
"type": "integer",
"title": "Total Completion Tokens",
"default": 0
}
},
"type": "object",
@@ -14587,6 +14629,25 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",
"title": "Resets At"
}
},
"type": "object",
"required": ["used", "limit", "resets_at"],
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserHistoryResponse": {
"properties": {
"history": {

View File

@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
ref={ref}
data-sidebar="trigger"
variant="ghost"
size="icon"
onClick={(event) => {
onClick?.(event);
toggleSidebar();

View File

@@ -1,343 +0,0 @@
# Workspace & Media File Architecture
This document describes the architecture for handling user files in AutoGPT Platform, covering persistent user storage (Workspace) and ephemeral media processing pipelines.
## Overview
The platform has two distinct file-handling layers:
| Layer | Purpose | Persistence | Scope |
|-------|---------|-------------|-------|
| **Workspace** | Long-term user file storage | Persistent (DB + GCS/local) | Per-user, session-scoped access |
| **Media Pipeline** | Ephemeral file processing for blocks | Temporary (local disk) | Per-execution |
## Database Models
### UserWorkspace
Represents a user's file storage space. Created on-demand (one per user).
```prisma
model UserWorkspace {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String @unique
Files UserWorkspaceFile[]
}
```
**Key points:**
- One workspace per user (enforced by `@unique` on `userId`)
- Created lazily via `get_or_create_workspace()`
- Uses upsert to handle race conditions
### UserWorkspaceFile
Represents a file stored in a user's workspace.
```prisma
model UserWorkspaceFile {
id String @id @default(uuid())
workspaceId String
name String // User-visible filename
path String // Virtual path (e.g., "/sessions/abc123/image.png")
storagePath String // Actual storage path (gcs://... or local://...)
mimeType String
sizeBytes BigInt
checksum String? // SHA256 for integrity
isDeleted Boolean @default(false)
deletedAt DateTime?
metadata Json @default("{}")
@@unique([workspaceId, path]) // Enforce unique paths within workspace
}
```
**Key points:**
- `path` is a virtual path for organizing files (not actual filesystem path)
- `storagePath` contains the actual GCS or local storage location
- Soft-delete pattern: `isDeleted` flag with `deletedAt` timestamp
- Path is modified on delete to free up the virtual path for reuse
---
## WorkspaceManager
**Location:** `backend/util/workspace.py`
High-level API for workspace file operations. Combines storage backend operations with database record management.
### Initialization
```python
from backend.util.workspace import WorkspaceManager
# Basic usage
manager = WorkspaceManager(user_id="user-123", workspace_id="ws-456")
# With session scoping (CoPilot sessions)
manager = WorkspaceManager(
user_id="user-123",
workspace_id="ws-456",
session_id="session-789"
)
```
### Session Scoping
When `session_id` is provided, files are isolated to `/sessions/{session_id}/`:
```python
# With session_id="abc123":
manager.write_file(content, "image.png")
# → stored at /sessions/abc123/image.png
# Cross-session access is explicit:
manager.read_file("/sessions/other-session/file.txt") # Works
```
**Why session scoping?**
- CoPilot conversations need file isolation
- Prevents file collisions between concurrent sessions
- Allows session cleanup without affecting other sessions
### Core Methods
| Method | Description |
|--------|-------------|
| `write_file(content, filename, path?, mime_type?, overwrite?)` | Write file to workspace |
| `read_file(path)` | Read file by virtual path |
| `read_file_by_id(file_id)` | Read file by ID |
| `list_files(path?, limit?, offset?, include_all_sessions?)` | List files |
| `delete_file(file_id)` | Soft-delete a file |
| `get_download_url(file_id, expires_in?)` | Get signed download URL |
| `get_file_info(file_id)` | Get file metadata |
| `get_file_info_by_path(path)` | Get file metadata by path |
| `get_file_count(path?, include_all_sessions?)` | Count files |
### Storage Backends
WorkspaceManager delegates to `WorkspaceStorageBackend`:
| Backend | When Used | Storage Path Format |
|---------|-----------|---------------------|
| `GCSWorkspaceStorage` | `media_gcs_bucket_name` is configured | `gcs://bucket/workspaces/{ws_id}/{file_id}/{filename}` |
| `LocalWorkspaceStorage` | No GCS bucket configured | `local://{ws_id}/{file_id}/{filename}` |
---
## store_media_file()
**Location:** `backend/util/file.py`
The media normalization pipeline. Handles various input types and normalizes them for processing or output.
### Purpose
Blocks receive files in many formats (URLs, data URIs, workspace references, local paths). `store_media_file()` normalizes these to a consistent format based on what the block needs.
### Input Types Handled
| Input Format | Example | How It's Processed |
|--------------|---------|-------------------|
| Data URI | `data:image/png;base64,iVBOR...` | Decoded, virus scanned, written locally |
| HTTP(S) URL | `https://example.com/image.png` | Downloaded, virus scanned, written locally |
| Workspace URI | `workspace://abc123` or `workspace:///path/to/file` | Read from workspace, virus scanned, written locally |
| Cloud path | `gcs://bucket/path` | Downloaded, virus scanned, written locally |
| Local path | `image.png` | Verified to exist in exec_file directory |
### Return Formats
The `return_format` parameter determines what you get back:
```python
from backend.util.file import store_media_file
# For local processing (ffmpeg, MoviePy, PIL)
local_path = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_local_processing"
)
# Returns: "image.png" (relative path in exec_file dir)
# For external APIs (Replicate, OpenAI, etc.)
data_uri = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_external_api"
)
# Returns: "data:image/png;base64,iVBOR..."
# For block output (adapts to execution context)
output = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_block_output"
)
# In CoPilot: Returns "workspace://file-id#image/png"
# In graphs: Returns "data:image/png;base64,..."
```
### Execution Context
`store_media_file()` requires an `ExecutionContext` with:
- `graph_exec_id` - Required for temp file location
- `user_id` - Required for workspace access
- `workspace_id` - Optional; enables workspace features
- `session_id` - Optional; for session scoping in CoPilot
---
## Responsibility Boundaries
### Virus Scanning
| Component | Scans? | Notes |
|-----------|--------|-------|
| `store_media_file()` | ✅ Yes | Scans **all** content before writing to local disk |
| `WorkspaceManager.write_file()` | ✅ Yes | Scans content before persisting |
**Scanning happens at:**
1. `store_media_file()` — scans everything it downloads/decodes
2. `WorkspaceManager.write_file()` — scans before persistence
Tools like `WriteWorkspaceFileTool` don't need to scan because `WorkspaceManager.write_file()` handles it.
### Persistence
| Component | Persists To | Lifecycle |
|-----------|-------------|-----------|
| `store_media_file()` | Temp dir (`/tmp/exec_file/{exec_id}/`) | Cleaned after execution |
| `WorkspaceManager` | GCS or local storage + DB | Persistent until deleted |
**Automatic cleanup:** `clean_exec_files(graph_exec_id)` removes temp files after execution completes.
---
## Decision Tree: WorkspaceManager vs store_media_file
```text
┌─────────────────────────────────────────────────────┐
│ What do you need to do with the file? │
└─────────────────────────────────────────────────────┘
┌─────────────┴─────────────┐
▼ ▼
Process in a block Store for user access
(ffmpeg, PIL, etc.) (CoPilot files, uploads)
│ │
▼ ▼
store_media_file() WorkspaceManager
with appropriate
return_format
┌──────┴──────┐
▼ ▼
"for_local_ "for_block_
processing" output"
│ │
▼ ▼
Get local Auto-saves to
path for workspace in
tools CoPilot context
Store for user access
├── write_file() ─── Upload + persist (scans internally)
├── read_file() / get_download_url() ─── Retrieve
└── list_files() / delete_file() ─── Manage
```
### Quick Reference
| Scenario | Use |
|----------|-----|
| Block needs to process a file with ffmpeg | `store_media_file(..., return_format="for_local_processing")` |
| Block needs to send file to external API | `store_media_file(..., return_format="for_external_api")` |
| Block returning a generated file | `store_media_file(..., return_format="for_block_output")` |
| API endpoint handling file upload | `WorkspaceManager.write_file()` (handles virus scanning internally) |
| API endpoint serving file download | `WorkspaceManager.get_download_url()` |
| Listing user's files | `WorkspaceManager.list_files()` |
---
## Key Files Reference
| File | Purpose |
|------|---------|
| `backend/data/workspace.py` | Database CRUD operations for UserWorkspace and UserWorkspaceFile |
| `backend/util/workspace.py` | `WorkspaceManager` class - high-level workspace API |
| `backend/util/workspace_storage.py` | Storage backends (GCS, local) and `WorkspaceStorageBackend` interface |
| `backend/util/file.py` | `store_media_file()` and media processing utilities |
| `backend/util/virus_scanner.py` | `VirusScannerService` and `scan_content_safe()` |
| `schema.prisma` | Database model definitions |
---
## Common Patterns
### Block Processing a User's File
```python
async def run(self, input_data, *, execution_context, **kwargs):
# Normalize input to local path
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# Process with local tools
output_path = process_video(local_path)
# Return (auto-saves to workspace in CoPilot)
result = await store_media_file(
file=output_path,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output", result
```
### API Upload Endpoint
```python
from backend.util.virus_scanner import VirusDetectedError, VirusScanError
async def upload_file(file: UploadFile, user_id: str, workspace_id: str):
content = await file.read()
# write_file handles virus scanning internally
manager = WorkspaceManager(user_id, workspace_id)
try:
workspace_file = await manager.write_file(
content=content,
filename=file.filename,
)
except VirusDetectedError:
raise HTTPException(status_code=400, detail="File rejected: virus detected")
except VirusScanError:
raise HTTPException(status_code=503, detail="Virus scanning unavailable")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"file_id": workspace_file.id}
```
---
## Configuration
| Setting | Purpose | Default |
|---------|---------|---------|
| `media_gcs_bucket_name` | GCS bucket for workspace storage | None (uses local) |
| `workspace_storage_dir` | Local storage directory | `{app_data}/workspaces` |
| `max_file_size_mb` | Maximum file size in MB | 100 |
| `clamav_service_enabled` | Enable virus scanning | true |
| `clamav_service_host` | ClamAV daemon host | localhost |
| `clamav_service_port` | ClamAV daemon port | 3310 |
| `clamav_max_concurrency` | Max concurrent scans to ClamAV daemon | 5 |
| `clamav_mark_failed_scans_as_clean` | If true, scan failures pass content through instead of rejecting (⚠️ security risk if ClamAV is unreachable) | false |