Compare commits

..

72 Commits

Author SHA1 Message Date
Nicholas Tindle
b3f35953ed feat(classic): add interactive config command to CLI
Add a new `config` command that opens a tabbed TUI for browsing and
editing AutoGPT settings. The UI allows users to configure settings
interactively rather than manually editing .env files.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:04:53 -06:00
Nicholas Tindle
d8d87f2853 Merge branch 'dev' into make-old-work 2026-01-29 19:32:34 -06:00
Nicholas Tindle
791e1d8982 fix(classic): resolve CI lint, type, and test failures
- Fix line-too-long in test_permissions.py docstring
- Fix type annotation in validators.py (callable -> Callable)
- Add --fresh flag to benchmark tests to prevent state resumption
- Exclude direct_benchmark/adapters from pyright (optional deps)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 14:31:11 -06:00
Nicholas Tindle
0040636948 fix(permissions): update wildcard handling for command patterns 2026-01-26 12:42:21 -06:00
Nicholas Tindle
c671af851f feat(classic): add platform_blocks to Agent, enable via PLATFORM_API_KEY
- Add PlatformBlocksComponent to Agent as a default component
- Component automatically enables when PLATFORM_API_KEY env var is set
- Config now uses UserConfigurable for env var support:
  - PLATFORM_API_KEY (required to enable)
  - PLATFORM_URL (default: https://platform.agpt.co)
  - PLATFORM_BLOCKS_ENABLED (default: true)
  - PLATFORM_TIMEOUT (default: 60)
- API key stored as SecretStr for security

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 17:30:24 -06:00
Nicholas Tindle
7dd181f4b0 feat(classic): make CWD the default agent workspace for CLI mode
In CLI mode, agents now work directly in the current directory instead of
being sandboxed to .autogpt/agents/{id}/workspace/. Agent state files are
still stored in .autogpt/agents/{id}/state.json.

Server mode retains the original sandboxed behavior for isolation.

Changes:
- Add workspace_root parameter to FileManagerComponent to detect CLI mode
- Update Agent to pass workspace_root when file_storage is rooted at workspace
- Adjust save_state paths based on mode (CLI uses .autogpt/ prefix)
- Add use_tools field to ActionProposal for parallel tool execution
- Support parallel tool execution in Agent.execute()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 15:20:12 -06:00
Nicholas Tindle
114856cef1 refactor(classic): improve prompt strategies with both general and code-specific guidance
- SystemComponent: Keep both general constraints (physical objects) and
  code-specific constraints (don't modify tests, check dependencies, no secrets)
- SystemComponent: Keep both general best practices (self-review, reflection)
  and code-specific best practices (read before modify, mimic style, verify)
- LATS: Keep general phase instructions while adding coding task priorities
- one_shot: Remove redundant 'text' field from AssistantThoughts, use 'reasoning'
- one_shot: Fix intro to clarify when to use ask_user instead of contradicting it
- one_shot: Add efficiency guidelines and parallel execution support
- Update UI to display reasoning as main thoughts (remove redundant REASONING line)
- Update test fixtures to match new AssistantThoughts schema

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:27:32 -06:00
Nicholas Tindle
68b9bd0c51 refactor(classic): use platform API for blocks instead of local loading
Simplify the platform_blocks component to fetch blocks from the
platform API (/api/v1/blocks) instead of loading them locally from
the monorepo. This removes the dependency on having the platform
backend code available.

- Remove loader.py (no longer needed)
- Update client.py with list_blocks() method
- Simplify component.py to use API for both search and execute
- Remove user_id from config (not needed by API)
- Update tests for API-based approach

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:16:39 -06:00
Nicholas Tindle
ff076b1f15 feat(classic): add platform blocks component for classic agents
Add search_blocks and execute_block commands that expose platform blocks
to classic agents:

- search_blocks: Local search by name, description, or category (fast, offline)
- execute_block: Execute via platform API with automatic credential handling

The loader automatically discovers the platform backend from the monorepo
structure without requiring manual PYTHONPATH configuration.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 13:10:57 -06:00
Nicholas Tindle
57fbab500b feat(classic): add external benchmark adapters for GAIA, SWE-bench, and AgentBench
Integrate standard AI agent benchmarks into the direct_benchmark infrastructure
using a plugin-based adapter pattern:

- Add BenchmarkAdapter base class with setup(), load_challenges(), and evaluate()
- Implement GAIAAdapter for the GAIA benchmark (requires HF token)
- Implement SWEBenchAdapter for SWE-bench (requires Docker)
- Implement AgentBenchAdapter for AgentBench multi-environment benchmark
- Extend HarnessConfig with benchmark options (--benchmark, --benchmark-split, etc.)
- Modify ParallelExecutor to use adapter's evaluate() for external benchmarks
- Fix runner to record finish step (was being skipped, breaking answer extraction)
- Add optional benchmarks dependency group with datasets and huggingface-hub
- Increase default benchmark timeout to 900s

Usage:
  poetry run direct-benchmark run \
    --benchmark agent-bench \
    --benchmark-subset dbbench \
    --strategies one_shot \
    --models claude

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 13:06:32 -06:00
Nicholas Tindle
6faabef24d fix(classic): always recreate Docker containers for code execution
Docker containers cannot have their mount bindings updated after creation.
When running benchmarks or multiple agent instances, the same container name
could be reused with a different workspace directory, causing the container
to still reference the OLD mount path. This resulted in "python: can't open
file '/workspace/temp*.py'" errors.

The fix: remove existing containers before creating new ones to ensure fresh
mount bindings to the current workspace directory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 23:57:02 -06:00
Nicholas Tindle
a67d475a69 fix(classic): handle parallel tool calls in action history
When prompts encourage parallel tool execution and the LLM makes multiple
tool calls simultaneously, the Anthropic API requires a tool_result message
for EACH tool_use. Previously, we only created one tool result for the first
tool call, causing "tool_use ids were found without tool_result blocks" errors.

This fix:
- Adds _make_result_messages() to create results for ALL tool calls
- Maps tool names to their outputs from parallel execution results
- Handles errors per-tool from the _errors list
- Falls back gracefully when results are missing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 23:18:15 -06:00
Nicholas Tindle
326554d89a style(classic): update black to 24.10.0 and reformat
Update black version to match pre-commit hook (24.10.0) and reformat
all files with the new version.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 10:51:54 -06:00
Nicholas Tindle
5e22a1888a chore: add classic benchmark reports and workspaces to gitignore
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 10:42:55 -06:00
Nicholas Tindle
a4d7b0142f fix(classic): resolve all pyright type errors
- Add missing strategies (lats, multi_agent_debate) to PromptStrategyName
- Fix method override signatures for reasoning_effort parameter
- Fix Pydantic Field() overload issues with helper function
- Fix BeautifulSoup Tag type narrowing in web_fetch.py
- Fix Optional member access in playwright_browser.py and rewoo.py
- Convert hasattr patterns to getattr for proper type narrowing
- Add proper type casts for Literal types
- Fix file storage path type conversions
- Exclude legacy challenges/ from pyright checking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 10:41:53 -06:00
Nicholas Tindle
7d6375f59c style(classic): fix flake8 line length issue
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:25:00 -06:00
Nicholas Tindle
aeec0ce509 chore: add test.db to gitignore 2026-01-20 01:24:22 -06:00
Nicholas Tindle
b32bfcaac5 chore: remove test.db from tracking 2026-01-20 01:24:00 -06:00
Nicholas Tindle
5373a6eb6e style(classic): fix code formatting with black
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:23:51 -06:00
Nicholas Tindle
98cde46ccb style(classic): fix import sorting with isort
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:23:33 -06:00
Nicholas Tindle
bd10da10d9 ci: update pre-commit hooks for consolidated classic Poetry project
- Consolidate classic poetry-install hooks into single hook using classic/
- Update isort hook to work with consolidated project structure
- Simplify flake8 hooks to use single classic/.flake8 config
- Consolidate pyright hooks into single hook for classic/
- Add direct_benchmark to hook coverage

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:21:50 -06:00
Nicholas Tindle
60fdee1345 fix(classic): resolve linting and formatting issues for CI compliance
- Update .flake8 config to exclude workspace directories and ignore E203
- Fix import sorting (isort) across multiple files
- Fix code formatting (black) across multiple files
- Remove unused imports and fix line length issues (flake8)
- Fix f-strings without placeholders and unused variables

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:16:38 -06:00
Nicholas Tindle
6f2783468c feat(classic): add sub-agent architecture and LATS/multi-agent debate strategies
Add comprehensive sub-agent spawning infrastructure that enables prompt
strategies to coordinate multiple agents for advanced reasoning patterns.

New files:
- forge/agent/execution_context.py: ExecutionContext, ResourceBudget,
  SubAgentHandle, and AgentFactory protocol for sub-agent lifecycle
- agent_factory/default_factory.py: DefaultAgentFactory implementation
- prompt_strategies/lats.py: Language Agent Tree Search using MCTS
  with sub-agents for action expansion and evaluation
- prompt_strategies/multi_agent_debate.py: Multi-agent debate with
  proposal, critique, and consensus phases

Key changes:
- BaseMultiStepPromptStrategy gains spawn_sub_agent(), run_sub_agent(),
  spawn_and_run(), and run_parallel() methods
- Agent class accepts optional ExecutionContext and injects it into strategies
- Sub-agents enabled by default (enable_sub_agents=True)
- Resource limits: max_depth=5, max_sub_agents=25, max_cycles=25

All 7 strategies now available in benchmark:
one_shot, rewoo, plan_execute, reflexion, tree_of_thoughts, lats, multi_agent_debate

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 01:01:28 -06:00
Nicholas Tindle
c1031b286d ci(classic): update CI workflows for consolidated Poetry project
Update all classic CI workflows to use the single consolidated
pyproject.toml at classic/ instead of individual project directories.

Changes:
- classic-autogpt-ci.yml: Run from classic/, update cache key and test paths
- classic-forge-ci.yml: Run from classic/, update cache key and test paths
- classic-benchmark-ci.yml: Run from classic/, use direct-benchmark command
- classic-python-checks.yml: Simplify to single job (no matrix needed)
- classic-autogpts-ci.yml: Update to use direct-benchmark for smoke tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:53:50 -06:00
Nicholas Tindle
b849eafb7f feat(direct_benchmark): enable shell command execution with safety denylist
Enable agents to execute shell commands during benchmarks by setting
execute_local_commands=True and using denylist mode to block dangerous
commands (rm, sudo, chmod, kill, etc.) while allowing safe operations.

Also adds ExecutePython challenge to test code execution capability.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:52:06 -06:00
Nicholas Tindle
572c3f5e0d refactor(classic): consolidate Poetry projects into single pyproject.toml
Merge forge/, original_autogpt/, and direct_benchmark/ into a single Poetry
project to eliminate cross-project path dependency issues.

Changes:
- Create classic/pyproject.toml with merged dependencies from all three projects
- Remove individual pyproject.toml and poetry.lock files from subdirectories
- Update all CLAUDE.md files to reflect commands run from classic/ root
- Update all README.md files with new installation and usage instructions

All packages are now included via the packages directive:
- forge/forge (core agent framework)
- original_autogpt/autogpt (AutoGPT agent)
- direct_benchmark/direct_benchmark (benchmark harness)

CLI entry points preserved: autogpt, serve, direct-benchmark

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:49:56 -06:00
Nicholas Tindle
89003a585d feat(direct_benchmark): show "would have passed" for timed-out challenges
When a challenge times out but the agent's solution would have passed
evaluation, this is now clearly indicated:

- Completion blocks show "TIMEOUT (would have passed)" in yellow
- Recent completions panel shows hourglass icon + "would pass" suffix
- Summary table has new "Would Pass" column
- Final summary shows "+N would pass" count
- Success rate includes "would pass" challenges

The evaluator still runs on timed-out challenges to calculate the score,
but success remains False. This gives visibility into near-misses that
just needed more time.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:30:00 -06:00
Nicholas Tindle
0e65785228 fix(direct_benchmark): don't mark timed-out challenges as passed
Previously, the evaluator would run on all results including timed-out
challenges. If the agent happened to write a working solution before
timing out, evaluation would pass and override success=True, resulting
in contradictory output showing both PASS and "timed out".

Now we skip evaluation for timed-out challenges - they cannot pass.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:25:41 -06:00
Nicholas Tindle
f07dff1cdd fix(direct_benchmark): add pytest dependency for challenge evaluation
The TicTacToe and other challenges use pytest-based test files for
evaluation. Without pytest installed in the benchmark virtualenv,
these evaluations were silently failing.

Root cause: test.py imports pytest but the package wasn't a dependency,
causing ModuleNotFoundError during evaluation subprocess.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:21:12 -06:00
Nicholas Tindle
00e02a4696 feat(direct_benchmark): add run ID to completion blocks
Include config:challenge:attempt and timestamp in completion block
header for easier debugging and log correlation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:14:23 -06:00
Nicholas Tindle
634bff8277 refactor(forge): replace Selenium with Playwright for web browsing
- Remove selenium.py and test_selenium.py
- Add playwright_browser.py with WebPlaywrightComponent
- Update web component exports to use Playwright
- Update dependencies in pyproject.toml/poetry.lock
- Minor agent and reflexion strategy improvements
- Update CLAUDE.md documentation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:57:17 -06:00
Nicholas Tindle
d591f36c7b fix(direct_benchmark): track cost from LLM provider
Previously cost was hardcoded to 0.0. Now extracts cumulative cost
from MultiProvider.get_incurred_cost() after each step execution.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:37:12 -06:00
Nicholas Tindle
a347bed0b1 feat(direct_benchmark): add incremental resume and selective reset
Benchmarks now automatically save progress and resume from where they
left off. State is persisted to .benchmark_state.json in reports dir.

Features:
- Auto-resume: runs skip already-completed challenges
- --fresh: clear all state and start over
- --retry-failures: re-run only failed challenges
- --reset-strategy/model/challenge: selective resets
- `state show/clear/reset` subcommands for state management
- Config mismatch detection with auto-reset

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:32:27 -06:00
Nicholas Tindle
4eeb6ee2b0 feat(direct_benchmark): add CI mode for non-interactive environments
Add --ci flag that disables Rich Live display while preserving
completion blocks. Auto-detects CI environment via CI env var or
non-TTY stdout. Prints progress every 10 completions for visibility.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:21:10 -06:00
Nicholas Tindle
7db962b9f9 feat(direct_benchmark): dynamic column layout up to 10 wide
- Calculate max columns based on terminal width (up to 10)
- Reduced panel width from 35 to 30 chars to fit more
- Wider terminals can now show more parallel runs side-by-side

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:15:16 -06:00
Nicholas Tindle
9108b21541 fix(direct_benchmark): parallel execution and always show completion blocks
Fixes:
- Use run_key (config:challenge) instead of just config_name for tracking
  active runs - allows multiple challenges from same config to run in parallel
- Add asyncio.sleep(0) yields to let multiple tasks acquire semaphore
  and start before any proceed with work
- Always print completion blocks (not just failures) for visibility

This should properly show 8/8 active runs when running with --parallel 8.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:13:56 -06:00
Nicholas Tindle
ffe9325296 feat(direct_benchmark): multi-panel UI with copy-paste completion blocks
UI improvements:
- Multi-column layout: each active config gets its own panel showing
  challenge name and step history (last 6 steps with status)
- Copy-paste completion blocks: when a challenge finishes (especially
  failures), prints a detailed block with all steps for easy debugging
- Configurable logging: suppresses noisy LLM provider warnings unless
  --debug flag is set
- Pass debug flag through harness to UI

Example active runs panel:
┌─ one_shot/claude ─┬─ rewoo/claude ────┐
│ ReadFile          │ WriteFile         │
│   ✓ #1 read_file  │   ✓ #1 think      │
│   ✓ #2 write_file │   ✓ #2 plan       │
│   ● step 3: ...   │   ● step 3: ...   │
└───────────────────┴───────────────────┘

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:10:34 -06:00
Nicholas Tindle
0a616d9267 feat(direct_benchmark): add step-level logging with colored prefixes
- Add step callback to AgentRunner for real-time step logging
- BenchmarkUI now shows:
  - Active runs with current step info
  - Recent steps panel with colored config prefixes
  - Proper Live display refresh (implements __rich_console__)
- Each config gets a distinct color for easy identification
- Verbose mode prints step logs immediately with config prefix
- Fix Live display not updating (pass UI object, not rendered content)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 23:02:20 -06:00
Nicholas Tindle
ab95077e5b refactor(forge): remove VCR cassettes, use real API calls with skip for forks
- Remove vcrpy and pytest-recording dependencies
- Remove tests/vcr/ directory and vcr_cassettes submodule
- Remove .gitmodules (only had cassette submodule)
- Simplify CI workflow - no more cassette checkout/push/PAT_REVIEW
- Tests requiring API keys now skip if not set (fork PRs)
- Update CLAUDE.md files to remove cassette references
- Fix broken agbenchmark path in pyproject.toml

Security improvement: removes need for PAT with cross-repo write access.
Fork PRs will have API-dependent tests skipped (GitHub protects secrets).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 22:51:57 -06:00
Nicholas Tindle
e477150979 Merge branch 'dev' into make-old-work 2026-01-19 22:30:46 -06:00
Nicholas Tindle
804430e243 refactor(classic): migrate from agbenchmark to direct_benchmark harness
- Remove old benchmark/ folder with agbenchmark framework
- Move challenges to direct_benchmark/challenges/
- Move analysis tools (analyze_reports.py, analyze_failures.py) to direct_benchmark/
- Move challenges_already_beaten.json to direct_benchmark/
- Update CI workflow to use direct_benchmark
- Update CLAUDE.md files with new benchmarking instructions
- Add benchmarking section to original_autogpt/CLAUDE.md

The direct_benchmark harness directly instantiates agents without HTTP
server overhead, enabling parallel execution with asyncio semaphore.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 22:29:51 -06:00
Nicholas Tindle
acb320d32d feat(classic): add noninteractive mode env var and benchmark config logging
- Add NONINTERACTIVE_MODE env var support to AppConfig for disabling
  user interaction during automated runs
- Benchmark harness now sets NONINTERACTIVE_MODE=True when starting agents
- Add agent configuration logging at server startup (model, strategy, etc.)
- Harness logs env vars being passed to agent for verification
- Add --agent-output flag to show full agent server output for debugging

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:40:24 -06:00
Nicholas Tindle
32f68d5999 feat(classic): add failure analysis tool and improve benchmark output
Benchmark improvements:
- Add analyze_failures.py for pattern detection and failure analysis
- Add informative step output: tool name, args, result status, cost
- Add --all and --matrix flags for comprehensive model/strategy testing
- Add --analyze-only and --no-analyze flags for flexible analysis control
- Auto-run failure analysis after benchmarks with markdown export
- Fix directory creation bug in ReportManager (add parents=True)

Prompt strategy enhancements:
- Implement full plan_execute, reflexion, rewoo, tree_of_thoughts strategies
- Add PROMPT_STRATEGY env var support for strategy selection
- Add extended thinking support for Anthropic models
- Add reasoning effort support for OpenAI o-series models

LLM provider improvements:
- Add thinking_budget_tokens config for Anthropic extended thinking
- Add reasoning_effort config for OpenAI reasoning models
- Improve error feedback for LLM self-correction

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 18:58:41 -06:00
Nicholas Tindle
49f56b4e8d feat(classic): enhance strategy benchmark harness with model comparison and bug fixes
- Add model comparison support to test harness (claude, openai, gpt5, opus presets)
- Add --models, --smart-llm, --fast-llm, --list-models CLI args
- Add real-time logging with timestamps and progress indicators
- Fix success parsing bug: read results[0].success instead of non-existent metrics.success
- Fix agbenchmark TestResult validation: use exception typename when value is empty
- Fix WebArena challenge validation: use strings instead of integers in instantiation_dict
- Fix Agent type annotations: create AnyActionProposal union for all prompt strategies
- Add pytest integration tests for the strategy benchmark harness

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 18:07:14 -06:00
Nicholas Tindle
bead811e73 docs(classic): add workspace, settings, and permissions documentation
Document the layered configuration system including:
- Workspace structure (.autogpt/ directory layout)
- Settings location (environment variables, workspace YAML, agent YAML)
- Permission system (check order, pattern syntax, approval scopes)
- Default security behavior

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:17:10 -06:00
Nicholas Tindle
013f728ebf feat(forge): improve tool call error feedback for LLM self-correction
When tool calls fail validation, the error messages now include:
- What arguments were actually provided
- The expected parameter schema with types and required/optional indicators

This helps LLMs understand and fix their mistakes when retrying,
rather than just being told a parameter is missing.

Example improved error:
  Invalid function call for write_file: 'contents' is a required property
  You provided: {"filename": 'story.txt'}
  Expected parameters: {"filename": string (required), "contents": string (required)}

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 11:49:17 -06:00
Nicholas Tindle
cda9572acd feat(forge): add lightweight web fetch component
Add WebFetchComponent for fast HTTP-based page fetching without browser
overhead. Uses trafilatura for intelligent content extraction.

Commands:
- fetch_webpage: Extract main content as text/markdown/xml
  - Removes navigation, ads, boilerplate automatically
  - Extracts page metadata (title, description, author, date)
  - Extracts and lists page links
  - Much faster than Selenium-based read_webpage

- fetch_raw_html: Get raw HTML for structure inspection
  - Optional truncation for large pages

Features:
- Trafilatura-powered content extraction (best-in-class accuracy)
- Automatic link extraction with relative URL resolution
- Page metadata extraction (OG tags, meta tags)
- Configurable timeout, max content length, max links
- Proper error handling for timeouts and HTTP errors
- 19 comprehensive tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 01:04:22 -06:00
Nicholas Tindle
e0784f8f6b refactor(forge): simplify deeply nested error handling in Anthropic provider
- Extract _get_tool_error_message helper method
- Replace 20+ levels of nesting with simple for loop
- Improve readability of tool_result construction
- Update benchmark poetry.lock

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 00:15:33 -06:00
Nicholas Tindle
3040f39136 feat(forge): modernize web search with tiered provider system
Replace basic DuckDuckGo-only search with a modern tiered system:

1. Tavily (primary) - AI-optimized results with content extraction
   - AI-generated answer summaries
   - Relevance scoring
   - Full page content extraction via search_and_extract command

2. Serper (secondary) - Fast, cheap Google SERP results
   - $0.30-1.00 per 1K queries
   - Real Google results without scraping

3. DDGS multi-engine (fallback) - Free, no API key required
   - Automatic fallback chain: DuckDuckGo → Bing → Brave → Google → etc.
   - 8 search backends supported

Key changes:
- Upgrade duckduckgo-search to ddgs v9.10 (renamed successor package)
- Add Tavily and Serper API integrations
- Implement automatic provider selection and fallback chain
- Add search_and_extract command for research with content extraction
- Add TAVILY_API_KEY and SERPER_API_KEY to env templates
- Update benchmark httpx constraint for ddgs compatibility
- 23 comprehensive tests for all providers and fallback scenarios

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 00:06:42 -06:00
Nicholas Tindle
515504c604 fix(classic): resolve pyright type errors in original_autogpt
- Change Agent class to use ActionProposal instead of OneShotAgentActionProposal
  to support multiple prompt strategy types
- Widen display_thoughts parameter type from AssistantThoughts to ModelWithSummary
- Fix speak attribute access in agent_protocol_server with hasattr check
- Add type: ignore comments for intentional thoughts field overrides in strategies
- Remove unused OneShotAgentActionProposal import

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 23:53:23 -06:00
Nicholas Tindle
18edeaeaf4 fix(classic): fix linting and formatting errors across codebase
- Fix 32+ flake8 E501 (line too long) errors by shortening descriptions
- Remove unused import in todo.py
- Fix test_todo.py argument order (config= keyword)
- Add type annotations to fix pyright errors where straightforward
- Add noqa comments for flake8 false positives in __init__.py
- Remove unused nonlocal declarations in main.py
- Run black and isort to fix formatting
- Update CLAUDE.md with improved linting commands

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 23:37:28 -06:00
Nicholas Tindle
44182aff9c feat(classic): add strategy benchmark test harness for CI
- Add test_prompt_strategies.py harness to compare prompt strategies
- Add pytest wrapper (test_strategy_benchmark.py) for CI integration
- Fix serve command (remove invalid --port flag, use AP_SERVER_PORT env)
- Fix test category (interface -> general)
- Add aiohttp-retry dependency for agbenchmark
- Add pytest markers: slow, integration, requires_agent

Usage:
  poetry run python agbenchmark_config/test_prompt_strategies.py --quick
  poetry run pytest tests/integration/test_strategy_benchmark.py -v

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 23:36:19 -06:00
Nicholas Tindle
864c5a7846 fix(classic): approve+feedback now executes command then sends feedback
Previously, when a user selected "Once" or "Always" with feedback (via Tab),
the command was NOT executed because UserFeedbackProvided was raised before
checking the approval scope. This fix changes the architecture from
exception-based to return-value-based.

Changes:
- Add PermissionCheckResult class with allowed, scope, and feedback fields
- Change check_command() to return PermissionCheckResult instead of bool
- Update prompt_fn signature to return (ApprovalScope, feedback) tuple
- Add pending_user_feedback mechanism to EpisodicActionHistory
- Update execute() to handle feedback after successful command execution
- Feedback message explicitly states "Command executed successfully"
- Add on_auto_approve callback for displaying auto-approved commands
- Add comprehensive tests for approval/denial with feedback scenarios

Behavior:
- Once + feedback → Execute command, then send feedback to agent
- Always + feedback → Execute command, save permission, send feedback
- Deny + feedback → Don't execute, send feedback to agent

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 22:32:43 -06:00
Nicholas Tindle
699fffb1a8 feat(classic): add Rich interactive selector for command approval
Adds a custom Rich-based interactive selector for the command approval
workflow. Features include:
- Arrow key navigation for selecting approval options
- Tab to add context to any selection (e.g., "Once + also check file x")
- Dedicated inline feedback option with shadow placeholder text
- Quick select with number keys 1-5
- Works within existing asyncio event loop (no prompt_toolkit dependency)

Also adds UIProvider abstraction pattern for future UI implementations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 21:49:43 -06:00
Nicholas Tindle
f0641c2d26 fix(classic): auto-advance plan steps in Plan-Execute strategy
The strategy was stuck in a loop because it tracked plan steps but never
advanced them - the record_step_success() method existed but was never
called by the agent's execution loop.

Fix by using a _pending_step_advance flag to track when an action has
been proposed. On the next parse_response_content() call, advance the
previous step before processing the new response. This keeps step
tracking self-contained in the strategy without requiring agent changes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 21:14:16 -06:00
Nicholas Tindle
94b6f74c95 feat(classic): add multiple prompt strategies for agent reasoning
Implement four new prompt strategies based on research papers:

- ReWOO: Reasoning Without Observation (5x token efficiency)
- Plan-and-Execute: Separate planning from execution phases
- Reflexion: Verbal reinforcement learning with episodic memory
- Tree of Thoughts: Deliberate problem solving with tree search

Each strategy extends a new BaseMultiStepPromptStrategy base class
with shared utilities. Strategies are selectable via PROMPT_STRATEGY
environment variable or config.prompt_strategy setting.

Fix JSONSchema generation issue where Optional/Union types created
anyOf schemas without direct type field - resolved by storing
plan/phase state in strategy instances rather than ActionProposal.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 20:33:10 -06:00
Nicholas Tindle
46aabab3ea feat(classic): upgrade to Python 3.12+ with CI testing on 3.12, 3.13, 3.14
- Update Python version constraint from ^3.10 to ^3.12 in all pyproject.toml
- Update classifiers to reflect Python 3.12, 3.13, 3.14 support
- Update dependencies for Python 3.13+ compatibility:
  - chromadb: ^0.4.10 -> ^1.4.0
  - numpy: >=1.26.0,<2.0.0 -> >=2.0.0
  - watchdog: 4.0.0 -> ^6.0.0
  - spacy: ^3.0.0 -> ^3.8.0 (numpy 2.x compatibility)
  - en-core-web-sm model: 3.7.1 -> 3.8.0
  - httpx (benchmark): ^0.24.0 -> ^0.27.0
- Update tool configuration:
  - Black target-version: py310 -> py312
  - Pyright pythonVersion: 3.10 -> 3.12
- Update Dockerfiles to use Python 3.12
- Update CI workflows to test on Python 3.12, 3.13, and 3.14
- Regenerate all poetry.lock files

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 20:25:11 -06:00
Nicholas Tindle
0a65df5102 fix(classic): always use native tool calling, fix N/A command loop
- Remove openai_functions config option - native tool calling is now always enabled
- Remove use_functions_api from BaseAgentConfiguration and prompt strategy
- Add use_prefill config to disable prefill for Anthropic (prefill + tools incompatible)
- Update anthropic dependency to ^0.45.0 for tools API support
- Simplify prompt strategy to always expect tool_calls from LLM response

This fixes the N/A command loop bug where models would output "N/A" as a
command name when function calling was disabled. With native tool calling
always enabled, models are forced to pick from valid tools only.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 19:54:40 -06:00
Nicholas Tindle
6fbd208fe3 chore: ignore .claude/settings.local.json in all directories
Update gitignore to use glob pattern for settings.local.json files
in any .claude directory. Also untrack the existing file.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:54:42 -06:00
Nicholas Tindle
8fc174ca87 refactor(classic): simplify log format by removing timestamps
Remove asctime from log formats since terminal output already has
timestamps from the logging infrastructure. Makes logs cleaner
and easier to read.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:52:47 -06:00
Nicholas Tindle
cacc89790f feat(classic): improve AutoGPT configuration and setup
Environment loading:
- Search for .env in multiple locations (cwd, ~/.autogpt, ~/.config/autogpt)
- Allows running autogpt from any directory
- Document search order in .env.template

Setup simplification:
- Remove interactive AI settings revision (was broken/unused)
- Simplify to just printing current settings
- Clean up unused imports

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:52:38 -06:00
Nicholas Tindle
b9113bee02 feat(classic): enhance existing components with new capabilities
CodeExecutorComponent:
- Add timeout and env_vars parameters to execution commands
- Add execute_shell_popen for streaming output
- Improve error handling with CodeTimeoutError

FileManagerComponent:
- Add file_info, file_search, file_copy, file_move commands
- Add directory_create, directory_list_tree commands
- Better path validation and error messages

GitOperationsComponent:
- Add git_log, git_show, git_branch commands
- Add git_stash, git_stash_pop, git_stash_list commands
- Add git_cherry_pick, git_revert, git_reset commands
- Add git_remote, git_fetch, git_pull, git_push commands

UserInteractionComponent:
- Add ask_multiple_choice for structured options
- Add notify_user for non-blocking notifications
- Add confirm_action for yes/no confirmations

WebSearchComponent:
- Minor error handling improvements

WebSeleniumComponent:
- Add get_page_content, execute_javascript commands
- Add take_element_screenshot command
- Add wait_for_element, scroll_page commands
- Improve element interaction reliability

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:52:27 -06:00
Nicholas Tindle
3f65da03e7 feat(classic): add new exception types for enhanced error handling
Add specialized exception classes for better error reporting:
- CodeTimeoutError: For code execution timeouts
- HTTPError: For HTTP request failures with status code/URL
- DataProcessingError: For JSON/CSV processing errors

Each exception includes helpful hints for users.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:52:10 -06:00
Nicholas Tindle
9e96d11b2d feat(classic): add utility components for agent capabilities
Add 6 new utility components to expand agent functionality:

- ArchiveHandlerComponent: ZIP/TAR archive operations (create, extract, list)
- ClipboardComponent: In-memory clipboard for copy/paste operations
- DataProcessorComponent: CSV/JSON data manipulation and analysis
- HTTPClientComponent: HTTP requests (GET, POST, PUT, DELETE)
- MathUtilsComponent: Mathematical calculations and statistics
- TextUtilsComponent: Text processing (regex, diff, encoding, hashing)

All components follow the forge component pattern with:
- CommandProvider for exposing commands
- DirectiveProvider for resources/best practices
- Comprehensive parameter validation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:50:52 -06:00
Nicholas Tindle
4c264b7ae9 feat(classic): add TodoComponent with LLM-powered decomposition
Add a task management component modeled after Claude Code's TodoWrite:
- TodoItem with recursive sub_items for hierarchical task structure
- todo_write: atomic list replacement with sub-items support
- todo_read: retrieve current todos with nested structure
- todo_clear: clear all todos
- todo_decompose: use smart LLM to break down tasks into sub-steps

Features:
- Hierarchical task tracking with independent status per sub-item
- MessageProvider shows todos in LLM context with proper indentation
- DirectiveProvider adds best practices for task management
- Graceful fallback when LLM provider not configured

Integrates with:
- original_autogpt Agent (full LLM decomposition support)
- ForgeAgent (basic task tracking, no decomposition)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 18:49:48 -06:00
Nicholas Tindle
0adbc0bd05 fix(classic): update CI for removed frontend and helper scripts
Remove references to deleted files (./run, cli.py, setup.py, frontend/)
from CI workflows. Replace ./run agent start with direct poetry commands
to start agent servers in background.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 17:41:11 -06:00
Nicholas Tindle
8f3291bc92 feat(classic): add workspace permissions system for agent commands
Add a layered permission system that controls agent command execution:

- Create autogpt.yaml in .autogpt/ folder with default allow/deny rules
- File operations in workspace allowed by default
- Sensitive files (.env, .key, .pem) blocked by default
- Dangerous shell commands (sudo, rm -rf) blocked by default
- Interactive prompts for unknown commands (y=agent, Y=workspace, n=deny)
- Agent-specific permissions stored in .autogpt/agents/{id}/permissions.yaml

Files added:
- forge/forge/config/workspace_settings.py - Pydantic models for settings
- forge/forge/permissions.py - CommandPermissionManager with pattern matching

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 17:39:33 -06:00
Nicholas Tindle
7a20de880d chore: add .autogpt/ to gitignore
The .autogpt/ directory is where AutoGPT stores agent data when running
from any directory. This should not be committed to version control.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 17:02:47 -06:00
Nicholas Tindle
ef8a6d2528 feat(classic): make AutoGPT installable and runnable from any directory
Add --workspace option to CLI that defaults to current working directory,
allowing users to run `autogpt` from any folder. Agent data is now stored
in `.autogpt/` subdirectory of the workspace instead of a hardcoded path.

Changes:
- Add -w/--workspace CLI option to run and serve commands
- Remove dependency on forge package location for PROJECT_ROOT
- Update config to use workspace instead of project_root
- Store agent data in .autogpt/ within workspace directory
- Update pyproject.toml files with proper PyPI metadata
- Fix outdated tests to match current implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 17:00:36 -06:00
Nicholas Tindle
fd66be2aaa chore(classic): remove unneeded files and add CLAUDE.md docs
- Remove deprecated Flutter frontend (replaced by autogpt_platform)
- Remove shell scripts (run, setup, autogpt.sh, etc.)
- Remove tutorials (outdated)
- Remove CLI-USAGE.md and FORGE-QUICKSTART.md
- Add CLAUDE.md files for Claude Code guidance

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 16:17:35 -06:00
Nicholas Tindle
ae2cc97dc4 feat(classic): add modern Anthropic models and fix deprecated API
- Add Claude 3.5 v2, Claude 4 Sonnet, Claude 4 Opus, and Claude 4.5 Opus models
- Add rolling aliases (CLAUDE_SONNET, CLAUDE_OPUS, CLAUDE_HAIKU)
- Fix deprecated beta.tools.messages.create API call to use standard messages.create
- Update anthropic SDK from ^0.25.1 to >=0.40,<1.0

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 16:15:16 -06:00
Nicholas Tindle
ea521eed26 wip: add supprot for new openai models (non working) 2025-12-26 10:02:17 -06:00
2363 changed files with 39970 additions and 826920 deletions

View File

@@ -6,11 +6,15 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -19,47 +23,22 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/original_autogpt
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -71,41 +50,23 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
@@ -116,12 +77,12 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
original_autogpt/tests/unit original_autogpt/tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -135,11 +96,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}
flags: autogpt-agent
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/original_autogpt/logs/
path: classic/logs/

View File

@@ -11,9 +11,6 @@ on:
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -22,9 +19,6 @@ on:
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- '!**/*.md'
defaults:
@@ -35,13 +29,9 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.10'
min-python-version: '3.12'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -55,22 +45,22 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
- name: Install dependencies
run: poetry install
- name: Run smoke tests with direct-benchmark
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"
CI: true

View File

@@ -1,17 +1,21 @@
name: Classic - AGBenchmark CI
name: Classic - Direct Benchmark CI
on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/benchmark/agbenchmark/challenges/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/benchmark/agbenchmark/challenges/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
concurrency:
@@ -23,23 +27,16 @@ defaults:
shell: bash
env:
min-python-version: '3.10'
min-python-version: '3.12'
jobs:
test:
permissions:
contents: read
benchmark-tests:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic/benchmark
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -47,71 +44,88 @@ jobs:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
- name: Install dependencies
run: poetry install
- name: Run pytest with coverage
- name: Run basic benchmark tests
run: |
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json
echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
self-test-with-agent:
# Run regression tests on maintain challenges
regression-tests:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -126,51 +140,23 @@ jobs:
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python -
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
- name: Run regression tests
working-directory: classic
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi
echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock
echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true
# poetry run agbenchmark --mock
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"

View File

@@ -6,13 +6,11 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -21,115 +19,38 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/forge
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi
if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch
git checkout $cassette_branch
# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
@@ -140,12 +61,15 @@ jobs:
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
forge/forge forge/tests
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -159,85 +83,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}
- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi
git config "$config_key" \
"Authorization: Basic $base64_pat"
cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"
echo "config_key=$config_key" >> $GITHUB_OUTPUT
- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi
cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi
- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
flags: forge
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/forge/logs/
path: classic/logs/

View File

@@ -1,60 +0,0 @@
name: Classic - Frontend CI/CD
on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'
- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v7
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"

View File

@@ -7,7 +7,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -16,7 +18,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
@@ -27,44 +31,13 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -81,42 +54,31 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -133,19 +95,16 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}

10
.gitignore vendored
View File

@@ -3,6 +3,7 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -159,6 +160,10 @@ CURRENT_BULLETIN.md
# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/
# Nodejs
package-lock.json
@@ -177,7 +182,10 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
# Test database
test.db

3
.gitmodules vendored
View File

@@ -1,3 +0,0 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes

View File

@@ -43,29 +43,10 @@ repos:
pass_filenames: false
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: poetry -C classic/original_autogpt install
# include forge source (since it's a path dependency)
files: ^classic/(original_autogpt|forge)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: poetry -C classic/forge install
files: ^classic/forge/poetry\.lock$
types: [file]
language: system
pass_filenames: false
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: poetry -C classic/benchmark install
files: ^classic/benchmark/poetry\.lock$
name: Check & Install dependencies - Classic
alias: poetry-install-classic
entry: poetry -C classic install
files: ^classic/poetry\.lock$
types: [file]
language: system
pass_filenames: false
@@ -116,26 +97,10 @@ repos:
language: system
- id: isort
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
types: [file, python]
language: system
@@ -149,26 +114,13 @@ repos:
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
# Use consolidated flake8 config at classic/.flake8
hooks:
- id: flake8
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]
- repo: local
hooks:
@@ -204,29 +156,10 @@ repos:
pass_filenames: false
- id: pyright
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
types: [file]
language: system
pass_filenames: false

View File

@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -1,368 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -1,344 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):

View File

@@ -1,23 +1,19 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -188,14 +166,13 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, or None if not found.
"""
session = await get_chat_session(session_id, user_id)
@@ -203,28 +180,11 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
id=session.session_id,
@@ -232,7 +192,6 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -252,112 +211,49 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
subscriber_queue = None
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -470,251 +366,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========

File diff suppressed because it is too large Load Diff

View File

@@ -1,704 +0,0 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
)
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
_listener_tasks.pop(queue_id, None)
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> bool:
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
Returns:
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
# Map response types to their corresponding classes
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
ResponseType.ERROR.value: StreamError,
ResponseType.USAGE.value: StreamUsage,
ResponseType.HEARTBEAT.value: StreamHeartbeat,
}
chunk_type = chunk_data.get("type")
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
if chunk_class is None:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
try:
return chunk_class(**chunk_data)
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
Cancels the XREAD-based listener task associated with this subscriber queue
to prevent resource leaks.
Args:
task_id: Task ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
listener_entry = _listener_tasks.pop(queue_id, None)
if listener_entry is None:
logger.debug(
f"No listener task found for task {task_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_task_id, listener_task = listener_entry
if stored_task_id != task_id:
logger.warning(
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
)
if listener_task.done():
logger.debug(f"Listener task for task {task_id} already completed")
return
# Cancel the listener task
listener_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(listener_task, timeout=5.0)
except asyncio.CancelledError:
# Expected - the task was successfully cancelled
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for task {task_id}"
)
except Exception as e:
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),

View File

@@ -2,58 +2,30 @@
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
# Core functions
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",
"get_agent_as_json",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"is_external_service_configured",
"check_external_service_health",
# Error handling
"get_user_message_for_error",
]

View File

@@ -1,25 +1,13 @@
"""Core agent generation functions."""
import logging
import re
import uuid
from typing import Any, NotRequired, TypedDict
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import (
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.data.graph import Graph, Link, Node, create_graph
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
@@ -28,74 +16,6 @@ from .service import (
logger = logging.getLogger(__name__)
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment."""
status: str
correctness_score: NotRequired[float]
activity_summary: NotRequired[str]
class LibraryAgentSummary(TypedDict):
"""Summary of a library agent for sub-agent composition.
Includes recent executions to help the LLM decide whether to use this agent.
Each execution shows status, correctness_score (0-1), and activity_summary.
"""
graph_id: str
graph_version: int
name: str
description: str
input_schema: dict[str, Any]
output_schema: dict[str, Any]
recent_executions: NotRequired[list[ExecutionSummary]]
class MarketplaceAgentSummary(TypedDict):
"""Summary of a marketplace agent for sub-agent composition."""
name: str
description: str
sub_heading: str
creator: str
is_marketplace_agent: bool
class DecompositionStep(TypedDict, total=False):
"""A single step in decomposed instructions."""
description: str
action: str
block_name: str
tool: str
name: str
class DecompositionResult(TypedDict, total=False):
"""Result from decompose_goal - can be instructions, questions, or error."""
type: str
steps: list[DecompositionStep]
questions: list[dict[str, Any]]
error: str
error_type: str
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
@@ -116,422 +36,15 @@ def _check_service_configured() -> None:
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def extract_uuids_from_text(text: str) -> list[str]:
"""Extract all UUID v4 strings from text.
Args:
text: Text that may contain UUIDs (e.g., user's goal description)
Returns:
List of unique UUIDs found in the text (lowercase)
"""
matches = _UUID_PATTERN.findall(text)
return list({m.lower() for m in matches})
async def get_library_agent_by_id(
user_id: str, agent_id: str
) -> LibraryAgentSummary | None:
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
This function tries multiple lookup strategies:
1. First tries to find by graph_id (AgentGraph primary key)
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
This handles both cases:
- User provides graph_id (e.g., from AgentExecutorBlock)
- User provides library agent ID (e.g., from library URL)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
LibraryAgentSummary if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except DatabaseError:
raise
except Exception as e:
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
max_results: int = 15,
) -> list[LibraryAgentSummary]:
"""Fetch user's library agents formatted for Agent Generator.
Uses search-based fetching to return relevant agents instead of all agents.
This is more scalable for users with large libraries.
Includes recent_executions list to help the LLM assess agent quality:
- Each execution has status, correctness_score (0-1), and activity_summary
- This gives the LLM concrete examples of recent performance
Args:
user_id: The user ID
search_query: Optional search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
max_results: Maximum number of agents to return (default 15)
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
)
results: list[LibraryAgentSummary] = []
for agent in response.agents:
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
continue
summary = LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
if agent.recent_executions:
exec_summaries: list[ExecutionSummary] = []
for ex in agent.recent_executions:
exec_sum = ExecutionSummary(status=ex.status)
if ex.correctness_score is not None:
exec_sum["correctness_score"] = ex.correctness_score
if ex.activity_summary:
exec_sum["activity_summary"] = ex.activity_summary
exec_summaries.append(exec_sum)
summary["recent_executions"] = exec_summaries
results.append(summary)
return results
except DatabaseError:
raise
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
return []
async def search_marketplace_agents_for_generation(
search_query: str,
max_results: int = 10,
) -> list[LibraryAgentSummary]:
"""Search marketplace agents formatted for Agent Generator.
Fetches marketplace agents and their full schemas so they can be used
as sub-agents in generated workflows.
Args:
search_query: Search term to find relevant public agents
max_results: Maximum number of agents to return (default 10)
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
)
agents_with_graphs = [
agent for agent in response.agents if agent.agent_graph_id
]
if not agents_with_graphs:
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
graph_id = agent.agent_graph_id
if graph_id and graph_id in graphs:
graph = graphs[graph_id]
results.append(
LibraryAgentSummary(
graph_id=graph.id,
graph_version=graph.version,
name=agent.agent_name,
description=agent.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
return results
except Exception as e:
logger.warning(f"Failed to search marketplace agents: {e}")
return []
async def get_all_relevant_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
include_library: bool = True,
include_marketplace: bool = True,
max_library_results: int = 15,
max_marketplace_results: int = 10,
) -> list[AgentSummary]:
"""Fetch relevant agents from library and/or marketplace.
Searches both user's library and marketplace by default.
Explicitly mentioned UUIDs in the search query are always looked up.
Args:
user_id: The user ID
search_query: Search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
include_library: Whether to search user's library (default True)
include_marketplace: Whether to also search marketplace (default True)
max_library_results: Max library agents to return (default 15)
max_marketplace_results: Max marketplace agents to return (default 10)
Returns:
List of AgentSummary with full schemas (both library and marketplace agents)
"""
agents: list[AgentSummary] = []
seen_graph_ids: set[str] = set()
if search_query:
mentioned_uuids = extract_uuids_from_text(search_query)
for graph_id in mentioned_uuids:
if graph_id == exclude_graph_id:
continue
agent = await get_library_agent_by_graph_id(user_id, graph_id)
agent_graph_id = agent.get("graph_id") if agent else None
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(agent_graph_id)
logger.debug(
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
)
if include_library:
library_agents = await get_library_agents_for_generation(
user_id=user_id,
search_query=search_query,
exclude_graph_id=exclude_graph_id,
max_results=max_library_results,
)
for agent in library_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
if include_marketplace and search_query:
marketplace_agents = await search_marketplace_agents_for_generation(
search_query=search_query,
max_results=max_marketplace_results,
)
for agent in marketplace_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
return agents
def extract_search_terms_from_steps(
decomposition_result: DecompositionResult | dict[str, Any],
) -> list[str]:
"""Extract search terms from decomposed instruction steps.
Analyzes the decomposition result to extract relevant keywords
for additional library agent searches.
Args:
decomposition_result: Result from decompose_goal containing steps
Returns:
List of unique search terms extracted from steps
"""
search_terms: list[str] = []
if decomposition_result.get("type") != "instructions":
return search_terms
steps = decomposition_result.get("steps", [])
if not steps:
return search_terms
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
for step in steps:
for key in step_keys:
value = step.get(key) # type: ignore[union-attr]
if isinstance(value, str) and len(value) > 3:
search_terms.append(value)
seen: set[str] = set()
unique_terms: list[str] = []
for term in search_terms:
term_lower = term.lower()
if term_lower not in seen:
seen.add(term_lower)
unique_terms.append(term)
return unique_terms
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
) -> list[AgentSummary] | list[dict[str, Any]]:
"""Enrich library agents list with additional searches based on decomposed steps.
This implements two-phase search: after decomposition, we search for additional
relevant agents based on the specific steps identified.
Args:
user_id: The user ID
decomposition_result: Result from decompose_goal containing steps
existing_agents: Already fetched library agents from initial search
exclude_graph_id: Optional graph ID to exclude
include_marketplace: Whether to also search marketplace
max_additional_results: Max additional agents per search term (default 10)
Returns:
Combined list of library agents (existing + newly discovered)
"""
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
for agent in existing_agents:
agent_name = agent.get("name")
if agent_name and isinstance(agent_name, str):
existing_names.add(agent_name.lower())
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
for term in search_terms[:3]:
try:
additional_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=term,
exclude_graph_id=exclude_graph_id,
include_marketplace=include_marketplace,
max_library_results=max_additional_results,
max_marketplace_results=5,
)
for agent in additional_agents:
agent_name = agent.get("name")
if not agent_name or not isinstance(agent_name, str):
continue
agent_name_lower = agent_name.lower()
if agent_name_lower in existing_names:
continue
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and graph_id in existing_ids:
continue
all_agents.append(agent)
existing_names.add(agent_name_lower)
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
except DatabaseError:
logger.error(f"Database error searching for agents with term '{term}'")
raise
except Exception as e:
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
)
logger.debug(
f"Enriched library agents: {len(existing_agents)} initial + "
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
)
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
@@ -541,47 +54,29 @@ async def decompose_goal(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
return await decompose_goal_external(description, context)
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
result = await generate_agent_external(instructions)
if result:
# Check if it's an error response - pass through as-is
if isinstance(result, dict) and result.get("type") == "error":
return result
# Ensure required fields for successful agent generation
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
@@ -591,12 +86,6 @@ async def generate_agent(
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
pass
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
@@ -605,55 +94,25 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
Returns:
Graph ready for saving
Raises:
AgentJsonValidationError: If required fields are missing from nodes or links
"""
nodes = []
for idx, n in enumerate(agent_json.get("nodes", [])):
block_id = n.get("block_id")
if not block_id:
node_id = n.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Node '{node_id}' is missing required field 'block_id'"
)
for n in agent_json.get("nodes", []):
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=block_id,
block_id=n["block_id"],
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for idx, link_data in enumerate(agent_json.get("links", [])):
source_id = link_data.get("source_id")
sink_id = link_data.get("sink_id")
source_name = link_data.get("source_name")
sink_name = link_data.get("sink_name")
missing_fields = []
if not source_id:
missing_fields.append("source_id")
if not sink_id:
missing_fields.append("sink_id")
if not source_name:
missing_fields.append("source_name")
if not sink_name:
missing_fields.append("sink_name")
if missing_fields:
link_id = link_data.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
)
for link_data in agent_json.get("links", []):
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
is_static=link_data.get("is_static", False),
)
links.append(link)
@@ -674,40 +133,22 @@ def _reassign_node_ids(graph: Graph) -> None:
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4())
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.
Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
@@ -721,27 +162,33 @@ async def save_agent_to_library(
Returns:
Tuple of (created Graph, LibraryAgent)
"""
# Populate user_id in AgentExecutorBlock nodes before conversion
_populate_agent_executor_user_ids(agent_json, user_id)
from backend.data.graph import get_graph_all_versions
graph = json_to_graph(agent_json)
if is_update:
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
@@ -752,15 +199,26 @@ async def save_agent_to_library(
return created_graph, library_agents[0]
def graph_to_json(graph: Graph) -> dict[str, Any]:
"""Convert a Graph object to JSON format for the agent generator.
async def get_agent_as_json(
graph_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
graph: Graph object to convert
graph_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict
Agent as JSON dict or None if not found
"""
from backend.data.graph import get_graph
# Try to get the graph (version=None gets the active version)
graph = await get_graph(graph_id, version=None, user_id=user_id)
if not graph:
return None
# Convert to JSON format
nodes = []
for node in graph.nodes:
nodes.append(
@@ -797,41 +255,8 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
}
async def get_agent_as_json(
agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
agent_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
graph = await get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
pass
if not graph:
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -843,57 +268,14 @@ async def generate_agent_patch(
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)

View File

@@ -1,43 +1,11 @@
"""Error handling utilities for agent generator."""
import re
def _sanitize_error_details(details: str) -> str:
"""Sanitize error details to remove sensitive information.
Strips common patterns that could expose internal system info:
- File paths (Unix and Windows)
- Database connection strings
- URLs with credentials
- Stack trace internals
Args:
details: Raw error details string
Returns:
Sanitized error details safe for user display
"""
sanitized = re.sub(
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
)
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
sanitized = re.sub(
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
)
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
sanitized = re.sub(r", line \d+", "", sanitized)
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
return sanitized.strip()
def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
error_details: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.
@@ -51,45 +19,25 @@ def get_user_message_for_error(
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
error_details: Optional additional details about the error
Returns:
User-friendly error message suitable for display to the user
"""
base_message = ""
if error_type == "llm_parse_error":
base_message = (
return (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
base_message = (
return (
validation_message
or "The generated agent failed validation. "
"This usually happens when the agent structure doesn't match "
"what the platform expects. Please try simplifying your goal "
"or breaking it into smaller parts."
or "The request failed validation. Please try rephrasing."
)
elif error_type == "patch_error":
base_message = (
"Failed to apply the changes. The modification couldn't be "
"validated. Please try a different approach or simplify the change."
)
return "Failed to apply the changes. Please try a different approach."
elif error_type in ("timeout", "llm_timeout"):
base_message = (
"The request took too long to process. This can happen with "
"complex agents. Please try again or simplify your goal."
)
return "The request took too long. Please try again."
elif error_type in ("rate_limit", "llm_rate_limit"):
base_message = "The service is currently busy. Please try again in a moment."
return "The service is currently busy. Please try again in a moment."
else:
base_message = f"Failed to {operation}. Please try again."
if error_details:
details = _sanitize_error_details(error_details)
if len(details) > 200:
details = details[:200] + "..."
base_message += f"\n\nTechnical details: {details}"
return base_message
return f"Failed to {operation}. Please try again."

View File

@@ -117,16 +117,13 @@ def _get_client() -> httpx.AsyncClient:
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
@@ -139,12 +136,11 @@ async def decompose_goal_external(
"""
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
# Build the request payload
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
try:
response = await client.post("/api/decompose-description", json=payload)
@@ -211,46 +207,21 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
Agent JSON dict on success, or error dict {"type": "error", ...} on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()
@@ -258,7 +229,8 @@ async def generate_agent_external(
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
f"Agent Generator generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
@@ -279,52 +251,27 @@ async def generate_agent_external(
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
Updated agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()
@@ -368,77 +315,6 @@ async def generate_agent_patch_external(
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.

View File

@@ -1,7 +1,6 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
import re
from typing import Literal
from backend.api.features.library import db as library_db
@@ -20,85 +19,6 @@ logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
_UUID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
async def search_agents(
query: str,
@@ -149,37 +69,29 @@ async def search_agents(
is_featured=False,
)
)
else:
if _is_uuid(query):
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
if agent:
agents.append(agent)
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass

View File

@@ -8,9 +8,7 @@ from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -18,7 +16,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -99,10 +96,6 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -110,24 +103,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
decomposition_result = await decompose_goal(description, context)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -146,6 +124,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if the result is an error from the external service
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
@@ -165,6 +144,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
@@ -183,6 +163,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
@@ -209,27 +190,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
agent_json = await generate_agent(decomposition_result)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -248,6 +211,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
@@ -255,12 +219,7 @@ class CreateAgentTool(BaseTool):
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
@@ -273,24 +232,12 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -305,6 +252,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -322,7 +270,7 @@ class CreateAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -1,337 +0,0 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@property
def name(self) -> str:
return "customize_agent"
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
),
},
"modifications": {
"type": "string",
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "modifications"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
error="invalid_agent_id_format",
session_id=session_id,
)
creator_username, agent_slug = parts
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -9,7 +9,6 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,7 +16,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -105,10 +103,6 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -123,6 +117,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
@@ -132,34 +127,14 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
result = await generate_agent_patch(update_request, current_agent)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -178,19 +153,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
@@ -200,7 +162,6 @@ class EditAgentTool(BaseTool):
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
@@ -214,6 +175,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
@@ -232,6 +194,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Result is the updated agent JSON
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
@@ -239,6 +202,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -254,6 +218,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -271,7 +236,7 @@ class EditAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Base response model
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -30,7 +30,6 @@ from .models import (
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide

View File

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -5,8 +5,6 @@ import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
user_id, block
)
if missing_credentials:

View File

@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -266,14 +266,13 @@ async def match_user_credentials_to_graph(
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
# Find first matching credential by provider and type
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
)
@@ -297,17 +296,10 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
)
logger.info(
@@ -317,28 +309,6 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -39,7 +39,6 @@ async def list_library_agents(
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
include_executions: bool = False,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
@@ -50,9 +49,6 @@ async def list_library_agents(
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
include_executions: Whether to include execution data for status calculation.
Defaults to False for performance (UI fetches status separately).
Set to True when accurate status/metrics are needed (e.g., agent generator).
Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
@@ -80,6 +76,7 @@ async def list_library_agents(
"isArchived": False,
}
# Build search filter if applicable
if search_term:
where_clause["OR"] = [
{
@@ -96,6 +93,7 @@ async def list_library_agents(
},
]
# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -107,7 +105,7 @@ async def list_library_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(
user_id, include_nodes=False, include_executions=include_executions
user_id, include_nodes=False, include_executions=False
),
order=order_by,
skip=(page - 1) * page_size,

View File

@@ -9,7 +9,6 @@ import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
if TYPE_CHECKING:
@@ -17,10 +16,10 @@ if TYPE_CHECKING:
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
class MarketplaceListingCreator(pydantic.BaseModel):
@@ -40,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator
class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.
Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""
status: str
correctness_score: float | None = None
activity_summary: str | None = None
def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -73,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -89,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None
input_schema: dict[str, Any]
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -106,19 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)
# Whether the user can access the underlying graph
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
# Whether the agent is marked as favorite by the user
is_favorite: bool
# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None
# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None
@staticmethod
@@ -142,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -154,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -162,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100
correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)
recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)
# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -249,15 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
settings=GraphSettings.model_validate(agent.settings),
marketplace_listing=marketplace_listing_data,
)
@@ -283,15 +220,18 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:

View File

@@ -112,7 +112,6 @@ async def get_store_agents(
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
agent_graph_id=agent.get("agentGraphId", ""),
)
store_agents.append(store_agent)
except Exception as e:
@@ -171,7 +170,6 @@ async def get_store_agents(
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.agentGraphId,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)

View File

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Use a unique search term to avoid matching other test data
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
# Create multiple items
content_ids = []
for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"{unique_term} item number {i}",
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
# Get second page
page2_results, total2 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,

View File

@@ -600,7 +600,6 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -660,7 +659,6 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

View File

@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
description: str
runs: int
rating: float
agent_graph_id: str
class StoreAgentsResponse(pydantic.BaseModel):

View File

@@ -26,13 +26,11 @@ def test_store_agent():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
@@ -48,7 +46,6 @@ def test_store_agents_response():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(

View File

@@ -82,7 +82,6 @@ def test_get_agents_featured(
description="Featured agent description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-1",
)
],
pagination=store_model.Pagination(
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
description="Creator agent description",
runs=50,
rating=4.0,
agent_graph_id="test-graph-2",
)
],
pagination=store_model.Pagination(
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
description="Top agent description",
runs=1000,
rating=5.0,
agent_graph_id="test-graph-3",
)
],
pagination=store_model.Pagination(
@@ -220,7 +217,6 @@ def test_get_agents_search(
description="Specific search term description",
runs=75,
rating=4.2,
agent_graph_id="test-graph-search",
)
],
pagination=store_model.Pagination(
@@ -266,7 +262,6 @@ def test_get_agents_category(
description="Category agent description",
runs=60,
rating=4.1,
agent_graph_id="test-graph-category",
)
],
pagination=store_model.Pagination(
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
agent_graph_id="test-graph-2",
)
for i in range(5)
],

View File

@@ -33,7 +33,6 @@ class TestCacheDeletion:
description="Test description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=Pagination(

View File

@@ -40,10 +40,6 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
await asyncio.gather(execution_worker(), notification_worker())
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -32,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -279,6 +280,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
@@ -634,18 +638,11 @@ async def llm_call(
context_window = llm_model.context_window
if compress_prompt_to_fit:
result = await compress_context(
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
lossy_ok=True,
)
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)

View File

@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@property
def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -227,7 +230,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -324,7 +330,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(

View File

@@ -873,13 +873,14 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry
sync_all_provider_costs()
@func_retry
async def sync_block_to_db(block: Block) -> None:
for cls in get_blocks().values():
block = cls()
existing_block = await AgentBlock.prisma().find_first(
where={"OR": [{"id": block.id}, {"name": block.name}]}
)
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
return
continue
input_schema = json.dumps(block.input_schema.jsonschema())
output_schema = json.dumps(block.output_schema.jsonschema())
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
},
)
failed_blocks: list[str] = []
for cls in get_blocks().values():
block = cls()
try:
await sync_block_to_db(block)
except Exception as e:
logger.warning(
f"Failed to sync block {block.name} to database: {e}. "
"Block is still available in memory.",
exc_info=True,
)
failed_blocks.append(block.name)
if failed_blocks:
logger.error(
f"Failed to sync {len(failed_blocks)} block(s) to database: "
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
)
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> AnyBlockSchema | None:

View File

@@ -81,6 +81,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,

View File

@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
def __init__(self):
self._pubsub: AsyncPubSub | None = None
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()
async def close(self) -> None:
"""Close the PubSub connection if it exists."""
if self._pubsub is not None:
try:
await self._pubsub.close()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
finally:
self._pubsub = None
async def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
await self.connection, channel_key
)
assert isinstance(pubsub, AsyncPubSub)
self._pubsub = pubsub
if "*" in channel_key:
await pubsub.psubscribe(full_channel_name)

View File

@@ -1028,39 +1028,6 @@ async def get_graph(
return GraphModel.from_db(graph, for_export)
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
"""Batch-fetch multiple store-listed graphs by their IDs.
Only returns graphs that have approved store listings (publicly available).
Does not require permission checks since store-listed graphs are public.
Args:
*graph_ids: Variable number of graph IDs to fetch
Returns:
Dict mapping graph_id to GraphModel for graphs with approved store listings
"""
if not graph_ids:
return {}
store_listings = await StoreListingVersion.prisma().find_many(
where={
"agentGraphId": {"in": list(graph_ids)},
"submissionStatus": SubmissionStatus.APPROVED,
"isDeleted": False,
},
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
distinct=["agentGraphId"],
order={"agentGraphVersion": "desc"},
)
return {
listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
for listing in store_listings
if listing.AgentGraph
}
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,

View File

@@ -666,16 +666,10 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
try:
provider = self.discriminator_mapping[discriminator_value]
except KeyError:
raise ValueError(
f"Model '{discriminator_value}' is not supported. "
"It may have been deprecated. Please update your agent configuration."
)
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,

View File

@@ -17,7 +17,6 @@ from backend.data.analytics import (
get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring,
)
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
# Onboarding
increment_onboarding_runs = _(increment_onboarding_runs)
# OAuth
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# Onboarding
increment_onboarding_runs = d.increment_onboarding_runs
# OAuth
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details

View File

@@ -24,9 +24,11 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
@@ -36,11 +38,7 @@ from backend.monitoring import (
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import (
get_database_manager_async_client,
get_database_manager_client,
get_scheduler_client,
)
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
GraphNotFoundError,
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time()
db = get_database_manager_async_client()
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await db.increment_onboarding_runs(args.user_id)
await increment_onboarding_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -249,13 +246,8 @@ def cleanup_expired_files():
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
async def _cleanup():
db = get_database_manager_async_client()
return await db.cleanup_expired_oauth_tokens()
run_async(_cleanup())
run_async(cleanup_expired_oauth_tokens())
def execution_accuracy_alerts():

View File

@@ -1,39 +0,0 @@
from urllib.parse import urlparse
import fastapi
from fastapi.routing import APIRoute
from backend.api.features.integrations.router import router as integrations_router
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import utils as webhooks_utils
def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
app = fastapi.FastAPI()
app.include_router(integrations_router, prefix="/api/integrations")
provider = ProviderName.GITHUB
webhook_id = "webhook_123"
base_url = "https://example.com"
monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
route = next(
route
for route in integrations_router.routes
if isinstance(route, APIRoute)
and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
and "POST" in route.methods
)
expected_path = f"/api/integrations{route.path}".format(
provider=provider.value,
webhook_id=webhook_id,
)
actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
expected_base = urlparse(base_url)
assert (actual_url.scheme, actual_url.netloc) == (
expected_base.scheme,
expected_base.netloc,
)
assert actual_url.path == expected_path

View File

@@ -1,19 +1,10 @@
from __future__ import annotations
import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import Any
from tiktoken import encoding_for_model
from backend.util import json
if TYPE_CHECKING:
from openai import AsyncOpenAI
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------#
# CONSTANTS #
# ---------------------------------------------------------------------------#
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
"""
Carefully truncate tool message content while preserving tool structure.
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
Only truncates tool_result content, leaves tool_use intact.
"""
content = msg.get("content")
# OpenAI-style tool message: role="tool" with string content
if msg.get("role") == "tool" and isinstance(content, str):
if _tok_len(content, enc) > max_tokens:
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
return
# Anthropic-style: list content with tool_result items
if not isinstance(content, list):
return
@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
# ---------------------------------------------------------------------------#
def compress_prompt(
messages: list[dict],
target_tokens: int,
*,
model: str = "gpt-4o",
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
lossy_ok: bool = True,
) -> list[dict]:
"""
Shrink *messages* so that::
token_count(prompt) + reserve ≤ target_tokens
Strategy
--------
1. **Token-aware truncation** progressively halve a per-message cap
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
*content* of every message except the first and last. Tool shells
are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** if still over the limit, delete whole
messages working outward from the centre, **skipping** any message
that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** if still too big, truncate the *first* and
*last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
• raise ``ValueError`` when ``lossy_ok == False`` (default)
• return the partially-trimmed prompt when ``lossy_ok == True``
Parameters
----------
messages Complete chat history (will be deep-copied).
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
target_tokens Hard ceiling for prompt size **excluding** the model's
forthcoming answer.
reserve How many tokens you want to leave available for that
answer (`max_tokens` in your subsequent completion call).
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap we'll accept before moving to deletions.
lossy_ok If *True* return best-effort prompt instead of raising
after all trim passes have been exhausted.
Returns
-------
list[dict] A *new* messages list that abides by the rules above.
"""
enc = encoding_for_model(model) # best-match tokenizer
msgs = deepcopy(messages) # never mutate caller
def total_tokens() -> int:
"""Current size of *msgs* in tokens."""
return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens:
return msgs
# ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
# Keep first and last messages intact (unless they're tool messages)
if i == 0 or i == len(msgs) - 1:
continue
# Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 1 : token-aware truncation ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact
if _is_tool_message(m):
# For tool messages, only truncate tool result content, preserve structure
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
# Never truncate objective messages - they contain the core task
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
deletable_indices = []
for i in range(1, len(msgs) - 1): # Skip first and last
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
deletable_indices.append(i)
if not deletable_indices:
break # nothing more we can drop
# Delete from center outward - find the index closest to center
centre = len(msgs) // 2
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
del msgs[to_delete]
# ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last
if _is_tool_message(msgs[idx]):
# For tool messages at first/last position, truncate tool result content only
_truncate_tool_message_content(msgs[idx], enc, cap)
continue
text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 4 : success or fail-gracefully -----------------------------
if total_tokens() + reserve > target_tokens and not lossy_ok:
raise ValueError(
"compress_prompt: prompt still exceeds budget "
f"({total_tokens() + reserve} > {target_tokens})."
)
return msgs
def estimate_token_count(
messages: list[dict],
*,
@@ -175,8 +293,7 @@ def estimate_token_count(
-------
int Token count.
"""
token_model = _normalize_model_for_tokenizer(model)
enc = encoding_for_model(token_model)
enc = encoding_for_model(model) # best-match tokenizer
return sum(_msg_tokens(m, enc) for m in messages)
@@ -198,543 +315,6 @@ def estimate_token_count_str(
-------
int Token count.
"""
token_model = _normalize_model_for_tokenizer(model)
enc = encoding_for_model(token_model)
enc = encoding_for_model(model) # best-match tokenizer
text = json.dumps(text) if not isinstance(text, str) else text
return _tok_len(text, enc)
# ---------------------------------------------------------------------------#
# UNIFIED CONTEXT COMPRESSION #
# ---------------------------------------------------------------------------#
# Default thresholds
DEFAULT_TOKEN_THRESHOLD = 120_000
DEFAULT_KEEP_RECENT = 15
@dataclass
class CompressResult:
"""Result of context compression."""
messages: list[dict]
token_count: int
was_compacted: bool
error: str | None = None
original_token_count: int = 0
messages_summarized: int = 0
messages_dropped: int = 0
def _normalize_model_for_tokenizer(model: str) -> str:
"""Normalize model name for tiktoken tokenizer selection."""
if "/" in model:
model = model.split("/")[-1]
if "claude" in model.lower() or not any(
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
):
return "gpt-4o"
return model
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs from an assistant message.
Supports both formats:
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
Returns:
Set of tool_call IDs found in the message.
"""
ids: set[str] = set()
if msg.get("role") != "assistant":
return ids
# OpenAI format: tool_calls array
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_use blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tc_id = block.get("id")
if tc_id:
ids.add(tc_id)
return ids
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs that this message is responding to.
Supports both formats:
- OpenAI: {"role": "tool", "tool_call_id": "..."}
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
Returns:
Set of tool_call IDs this message responds to.
"""
ids: set[str] = set()
# OpenAI format: role=tool with tool_call_id
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_result blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
tc_id = block.get("tool_use_id")
if tc_id:
ids.add(tc_id)
return ids
def _is_tool_response_message(msg: dict) -> bool:
"""Check if message is a tool response (OpenAI or Anthropic format)."""
# OpenAI format
if msg.get("role") == "tool":
return True
# Anthropic format
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
return True
return False
def _remove_orphan_tool_responses(
messages: list[dict], orphan_ids: set[str]
) -> list[dict]:
"""
Remove tool response messages/blocks that reference orphan tool_call IDs.
Supports both OpenAI and Anthropic formats.
For Anthropic messages with mixed valid/orphan tool_result blocks,
filters out only the orphan blocks instead of dropping the entire message.
"""
result = []
for msg in messages:
# OpenAI format: role=tool - drop entire message if orphan
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id in orphan_ids:
continue
result.append(msg)
continue
# Anthropic format: content list may have mixed tool_result blocks
content = msg.get("content")
if isinstance(content, list):
has_tool_results = any(
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
)
if has_tool_results:
# Filter out orphan tool_result blocks, keep valid ones
filtered_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") == "tool_result"
and block.get("tool_use_id") in orphan_ids
)
]
# Only keep message if it has remaining content
if filtered_content:
msg = msg.copy()
msg["content"] = filtered_content
result.append(msg)
continue
result.append(msg)
return result
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.
When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").
This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.
Supports both formats:
- OpenAI: tool_calls array + role="tool" responses
- Anthropic: tool_use blocks + tool_result blocks
Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins
Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages
# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
# Find orphan tool responses (responses whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
response_ids = _extract_tool_response_ids_from_message(msg)
for tc_id in response_ids:
if tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)
if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages
# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]
# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
following_response_ids = _extract_tool_response_ids_from_message(
following_msg
)
if following_response_ids and following_response_ids & msg_tool_ids:
assistant_and_responses.append(following_msg)
elif not _is_tool_response_message(following_msg):
# Stop at first non-tool-response message
break
# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids
if not orphan_tool_call_ids:
# Found all missing assistants
break
if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
)
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
return recent_messages
async def _summarize_messages_llm(
messages: list[dict],
client: AsyncOpenAI,
model: str,
timeout: float = 30.0,
) -> str:
"""Summarize messages using an LLM."""
conversation = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if content and role in ("user", "assistant", "tool"):
conversation.append(f"{role.upper()}: {content}")
conversation_text = "\n\n".join(conversation)
if not conversation_text:
return "No conversation history available."
# Limit to ~100k chars for safety
MAX_CHARS = 100_000
if len(conversation_text) > MAX_CHARS:
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
response = await client.with_options(timeout=timeout).chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": (
"Create a detailed summary of the conversation so far. "
"This summary will be used as context when continuing the conversation.\n\n"
"Before writing the summary, analyze each message chronologically to identify:\n"
"- User requests and their explicit goals\n"
"- Your approach and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"You MUST include ALL of the following sections:\n\n"
"## 1. Primary Request and Intent\n"
"The user's explicit goals and what they are trying to accomplish.\n\n"
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions. "
"Include any user feedback on fixes.\n\n"
"## 5. Problem Solving\n"
"Issues that have been resolved and how they were addressed.\n\n"
"## 6. All User Messages\n"
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
"## 7. Pending Tasks\n"
"Work items the user explicitly requested that have not yet been completed.\n\n"
"## 8. Current Work\n"
"Precise description of what was being worked on most recently, including relevant context.\n\n"
"## 9. Next Steps\n"
"What should happen next, aligned with the user's most recent requests. "
"Include verbatim quotes of recent instructions if relevant."
),
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
temperature=0.3,
)
return response.choices[0].message.content or "No summary available."
async def compress_context(
messages: list[dict],
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
*,
model: str = "gpt-4o",
client: AsyncOpenAI | None = None,
keep_recent: int = DEFAULT_KEEP_RECENT,
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
) -> CompressResult:
"""
Unified context compression that combines summarization and truncation strategies.
Strategy (in order):
1. **LLM summarization** If client provided, summarize old messages into a
single context message while keeping recent messages intact. This is the
primary strategy for chat service.
2. **Content truncation** Progressively halve a per-message cap and truncate
bloated message content (tool outputs, large pastes). Preserves all messages
but shortens their content. Primary strategy when client=None (LLM blocks).
3. **Middle-out deletion** Delete whole messages one at a time from the center
outward, skipping tool messages and objective messages.
4. **First/last trim** Truncate first and last message content as last resort.
Parameters
----------
messages Complete chat history (will be deep-copied).
target_tokens Hard ceiling for prompt size.
model Model name for tokenization and summarization.
client AsyncOpenAI client. If provided, enables LLM summarization
as the first strategy. If None, skips to truncation strategies.
keep_recent Number of recent messages to preserve during summarization.
reserve Tokens to reserve for model response.
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap before moving to deletions.
Returns
-------
CompressResult with compressed messages and metadata.
"""
# Guard clause for empty messages
if not messages:
return CompressResult(
messages=[],
token_count=0,
was_compacted=False,
original_token_count=0,
)
token_model = _normalize_model_for_tokenizer(model)
enc = encoding_for_model(token_model)
msgs = deepcopy(messages)
def total_tokens() -> int:
return sum(_msg_tokens(m, enc) for m in msgs)
original_count = total_tokens()
# Already under limit
if original_count + reserve <= target_tokens:
return CompressResult(
messages=msgs,
token_count=original_count,
was_compacted=False,
original_token_count=original_count,
)
messages_summarized = 0
messages_dropped = 0
# ---- STEP 1: LLM summarization (if client provided) -------------------
# This is the primary compression strategy for chat service.
# Summarize old messages while keeping recent ones intact.
if client is not None:
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
system_msg = msgs[0] if has_system else None
# Calculate old vs recent messages
if has_system:
if len(msgs) > keep_recent + 1:
old_msgs = msgs[1:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs[1:] if len(msgs) > 1 else []
else:
if len(msgs) > keep_recent:
old_msgs = msgs[:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs
# Ensure tool pairs stay intact
slice_start = max(0, len(msgs) - keep_recent)
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
if old_msgs:
try:
summary_text = await _summarize_messages_llm(old_msgs, client, model)
summary_msg = {
"role": "assistant",
"content": f"[Previous conversation summary — for context only]: {summary_text}",
}
messages_summarized = len(old_msgs)
if has_system:
msgs = [system_msg, summary_msg] + recent_msgs
else:
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
# Always run this before truncation to ensure consistent token counting.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
if i == 0 or i == len(msgs) - 1:
continue
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 3: Token-aware content truncation ---------------------------
# Progressively halve per-message cap and truncate bloated content.
# This preserves all messages but shortens their content.
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]:
if _is_tool_message(m):
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2
# ---- STEP 4: Middle-out deletion --------------------------------------
# Delete messages one at a time from the center outward.
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
msg is not None
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
deletable.append(i)
if not deletable:
break
centre = len(msgs) // 2
to_delete = min(deletable, key=lambda i: abs(i - centre))
del msgs[to_delete]
messages_dropped += 1
# ---- STEP 5: Final trim on first/last ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1):
msg = msgs[idx]
if msg is None:
continue
if _is_tool_message(msg):
_truncate_tool_message_content(msg, enc, cap)
continue
text = msg.get("content") or ""
if _tok_len(text, enc) > cap:
msg["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2
# Filter out any None values that may have been introduced
final_msgs: list[dict] = [m for m in msgs if m is not None]
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
error = None
if final_count + reserve > target_tokens:
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
logger.warning(error)
return CompressResult(
messages=final_msgs,
token_count=final_count,
was_compacted=True,
error=error,
original_token_count=original_count,
messages_summarized=messages_summarized,
messages_dropped=messages_dropped,
)

View File

@@ -1,21 +1,10 @@
"""Tests for prompt utility functions, especially tool call token counting."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from tiktoken import encoding_for_model
from backend.util import json
from backend.util.prompt import (
CompressResult,
_ensure_tool_pairs_intact,
_msg_tokens,
_normalize_model_for_tokenizer,
_truncate_middle_tokens,
_truncate_tool_message_content,
compress_context,
estimate_token_count,
)
from backend.util.prompt import _msg_tokens, estimate_token_count
class TestMsgTokens:
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
assert total_tokens == expected_total
assert total_tokens > 20 # Should be substantial
class TestNormalizeModelForTokenizer:
"""Test model name normalization for tiktoken."""
def test_openai_models_unchanged(self):
"""Test that OpenAI models are returned as-is."""
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
def test_claude_models_normalized(self):
"""Test that Claude models are normalized to gpt-4o."""
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
def test_openrouter_paths_extracted(self):
"""Test that OpenRouter model paths are handled."""
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
def test_unknown_models_default_to_gpt4o(self):
"""Test that unknown models default to gpt-4o."""
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
class TestTruncateToolMessageContent:
"""Test tool message content truncation."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncate_openai_tool_message(self, enc):
"""Test truncation of OpenAI-style tool message with string content."""
long_content = "x" * 10000
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
assert len(msg["content"]) < len(long_content)
assert "" in msg["content"] # Has ellipsis marker
def test_truncate_anthropic_tool_result(self, enc):
"""Test truncation of Anthropic-style tool_result."""
long_content = "y" * 10000
msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": long_content,
}
],
}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
result_content = msg["content"][0]["content"]
assert len(result_content) < len(long_content)
assert "" in result_content
def test_preserve_tool_use_blocks(self, enc):
"""Test that tool_use blocks are not truncated."""
msg = {
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "some_function",
"input": {"key": "value" * 1000}, # Large input
}
],
}
original = json.dumps(msg["content"][0]["input"])
_truncate_tool_message_content(msg, enc, max_tokens=10)
# tool_use should be unchanged
assert json.dumps(msg["content"][0]["input"]) == original
def test_no_truncation_when_under_limit(self, enc):
"""Test that short content is not modified."""
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
original = msg["content"]
_truncate_tool_message_content(msg, enc, max_tokens=1000)
assert msg["content"] == original
class TestTruncateMiddleTokens:
"""Test middle truncation of text."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncates_long_text(self, enc):
"""Test that long text is truncated with ellipsis in middle."""
long_text = "word " * 1000
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
assert "" in result
assert result.startswith("word") # Head preserved
assert result.endswith("word ") # Tail preserved
def test_preserves_short_text(self, enc):
"""Test that short text is not modified."""
short_text = "Hello world"
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
assert result == short_text
class TestEnsureToolPairsIntact:
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
# ---- OpenAI Format Tests ----
def test_openai_adds_missing_tool_call(self):
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool response)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_call message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert "tool_calls" in result[0]
def test_openai_keeps_complete_pairs(self):
"""Test that complete OpenAI pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
]
recent = all_msgs[1:] # Include both tool_call and response
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_openai_multiple_tool_calls(self):
"""Test multiple OpenAI tool calls in one assistant message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (first tool response)
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_calls
assert len(result) == 4
assert result[0]["role"] == "assistant"
assert len(result[0]["tool_calls"]) == 2
# ---- Anthropic Format Tests ----
def test_anthropic_adds_missing_tool_use(self):
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "get_weather",
"input": {"location": "SF"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": "22°C and sunny",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_anthropic_keeps_complete_pairs(self):
"""Test that complete Anthropic pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_456",
"name": "calculator",
"input": {"expr": "2+2"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_456",
"content": "4",
}
],
},
]
recent = all_msgs[1:] # Include both tool_use and result
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_anthropic_multiple_tool_uses(self):
"""Test multiple Anthropic tool_use blocks in one message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check both..."},
{
"type": "tool_use",
"id": "toolu_1",
"name": "get_weather",
"input": {"city": "NYC"},
},
{
"type": "tool_use",
"id": "toolu_2",
"name": "get_weather",
"input": {"city": "LA"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_1",
"content": "Cold",
},
{
"type": "tool_result",
"tool_use_id": "toolu_2",
"content": "Warm",
},
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_uses
assert len(result) == 3
assert result[0]["role"] == "assistant"
tool_use_count = sum(
1 for b in result[0]["content"] if b.get("type") == "tool_use"
)
assert tool_use_count == 2
# ---- Mixed/Edge Case Tests ----
def test_anthropic_with_type_message_field(self):
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_abc",
"name": "search",
"input": {"q": "test"},
}
],
},
{
"role": "user",
"type": "message", # Extra field from smart_decision_maker
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_abc",
"content": "Found results",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result with 'type': 'message')
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_handles_no_tool_messages(self):
"""Test messages without tool calls."""
all_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
recent = all_msgs
start_index = 0
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert result == all_msgs
def test_handles_empty_messages(self):
"""Test empty message list."""
result = _ensure_tool_pairs_intact([], [], 0)
assert result == []
def test_mixed_text_and_tool_content(self):
"""Test Anthropic message with mixed text and tool_use content."""
all_msgs = [
{
"role": "assistant",
"content": [
{"type": "text", "text": "I'll help you with that."},
{
"type": "tool_use",
"id": "toolu_mixed",
"name": "search",
"input": {"q": "test"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_mixed",
"content": "Found results",
}
],
},
{"role": "assistant", "content": "Here are the results..."},
]
# Start from tool_result
recent = [all_msgs[1], all_msgs[2]]
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with tool_use
assert len(result) == 3
assert result[0]["content"][0]["type"] == "text"
assert result[0]["content"][1]["type"] == "tool_use"
class TestCompressContext:
"""Test the async compress_context function."""
@pytest.mark.asyncio
async def test_no_compression_needed(self):
"""Test messages under limit return without compression."""
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello!"},
]
result = await compress_context(messages, target_tokens=100000)
assert isinstance(result, CompressResult)
assert result.was_compacted is False
assert len(result.messages) == 2
assert result.error is None
@pytest.mark.asyncio
async def test_truncation_without_client(self):
"""Test that truncation works without LLM client."""
long_content = "x" * 50000
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": long_content},
{"role": "assistant", "content": "Response"},
]
result = await compress_context(
messages, target_tokens=1000, client=None, reserve=100
)
assert result.was_compacted is True
# Should have truncated without summarization
assert result.messages_summarized == 0
@pytest.mark.asyncio
async def test_with_mocked_llm_client(self):
"""Test summarization with mocked LLM client."""
# Create many messages to trigger summarization
messages = [{"role": "system", "content": "System prompt"}]
for i in range(30):
messages.append({"role": "user", "content": f"User message {i} " * 100})
messages.append(
{"role": "assistant", "content": f"Assistant response {i} " * 100}
)
# Mock the AsyncOpenAI client
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Summary of conversation"
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await compress_context(
messages,
target_tokens=5000,
client=mock_client,
keep_recent=5,
reserve=500,
)
assert result.was_compacted is True
# Should have attempted summarization
assert mock_client.with_options.called or result.messages_summarized > 0
@pytest.mark.asyncio
async def test_preserves_tool_pairs(self):
"""Test that tool call/response pairs stay together."""
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "func"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
{"role": "assistant", "content": "Done!"},
]
result = await compress_context(
messages, target_tokens=500, client=None, reserve=50
)
# Check that if tool response exists, its call exists too
tool_call_ids = set()
tool_response_ids = set()
for msg in result.messages:
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
tool_call_ids.add(tc["id"])
if msg.get("role") == "tool":
tool_response_ids.add(msg.get("tool_call_id"))
# All tool responses should have their calls
assert tool_response_ids <= tool_call_ids
@pytest.mark.asyncio
async def test_returns_error_when_cannot_compress(self):
"""Test that error is returned when compression fails."""
# Single huge message that can't be compressed enough
messages = [
{"role": "user", "content": "x" * 100000},
]
result = await compress_context(
messages, target_tokens=100, client=None, reserve=50
)
# Should have an error since we can't get below 100 tokens
assert result.error is not None
assert result.was_compacted is True
@pytest.mark.asyncio
async def test_empty_messages(self):
"""Test that empty messages list returns early without error."""
result = await compress_context([], target_tokens=1000)
assert result.messages == []
assert result.token_count == 0
assert result.was_compacted is False
assert result.error is None
class TestRemoveOrphanToolResponses:
"""Test _remove_orphan_tool_responses helper function."""
def test_removes_openai_orphan(self):
"""Test removal of orphan OpenAI tool response."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
{"role": "user", "content": "Hello"},
]
orphan_ids = {"call_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_keeps_valid_openai_tool(self):
"""Test that valid OpenAI tool responses are kept."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
]
orphan_ids = {"call_other"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_valid"
def test_filters_anthropic_mixed_blocks(self):
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_valid",
"content": "valid result",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan",
"content": "orphan result",
},
],
},
]
orphan_ids = {"toolu_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
# Should only have the valid tool_result, orphan filtered out
assert len(result[0]["content"]) == 1
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
def test_removes_anthropic_all_orphan(self):
"""Test removal of Anthropic message when all tool_results are orphans."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_orphan1",
"content": "result1",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan2",
"content": "result2",
},
],
},
]
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
# Message should be completely removed since no content left
assert len(result) == 0
def test_preserves_non_tool_messages(self):
"""Test that non-tool messages are preserved."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
orphan_ids = {"some_id"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert result == messages
class TestCompressResultDataclass:
"""Test CompressResult dataclass."""
def test_default_values(self):
"""Test default values are set correctly."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=10,
was_compacted=False,
)
assert result.error is None
assert result.original_token_count == 0 # Defaults to 0, not None
assert result.messages_summarized == 0
assert result.messages_dropped == 0
def test_all_fields(self):
"""Test all fields can be set."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=100,
was_compacted=True,
error="Some error",
original_token_count=500,
messages_summarized=10,
messages_dropped=5,
)
assert result.token_count == 100
assert result.was_compacted is True
assert result.error == "Some error"
assert result.original_token_count == 500
assert result.messages_summarized == 10
assert result.messages_dropped == 5

View File

@@ -1,22 +0,0 @@
-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026
-- Update AgentNode constant inputs
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
"data"::jsonb,
'{model}',
'"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';

View File

@@ -9,8 +9,7 @@
"sub_heading": "Creator agent subheading",
"description": "Creator agent description",
"runs": 50,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
}
],
"pagination": {

View File

@@ -9,8 +9,7 @@
"sub_heading": "Category agent subheading",
"description": "Category agent description",
"runs": 60,
"rating": 4.1,
"agent_graph_id": "test-graph-category"
"rating": 4.1
}
],
"pagination": {

View File

@@ -9,8 +9,7 @@
"sub_heading": "Agent 0 subheading",
"description": "Agent 0 description",
"runs": 0,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
},
{
"slug": "agent-1",
@@ -21,8 +20,7 @@
"sub_heading": "Agent 1 subheading",
"description": "Agent 1 description",
"runs": 10,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
},
{
"slug": "agent-2",
@@ -33,8 +31,7 @@
"sub_heading": "Agent 2 subheading",
"description": "Agent 2 description",
"runs": 20,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
},
{
"slug": "agent-3",
@@ -45,8 +42,7 @@
"sub_heading": "Agent 3 subheading",
"description": "Agent 3 description",
"runs": 30,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
},
{
"slug": "agent-4",
@@ -57,8 +53,7 @@
"sub_heading": "Agent 4 subheading",
"description": "Agent 4 description",
"runs": 40,
"rating": 4.0,
"agent_graph_id": "test-graph-2"
"rating": 4.0
}
],
"pagination": {

View File

@@ -9,8 +9,7 @@
"sub_heading": "Search agent subheading",
"description": "Specific search term description",
"runs": 75,
"rating": 4.2,
"agent_graph_id": "test-graph-search"
"rating": 4.2
}
],
"pagination": {

View File

@@ -9,8 +9,7 @@
"sub_heading": "Top agent subheading",
"description": "Top agent description",
"runs": 1000,
"rating": 5.0,
"agent_graph_id": "test-graph-3"
"rating": 5.0
}
],
"pagination": {

View File

@@ -9,8 +9,7 @@
"sub_heading": "Featured agent subheading",
"description": "Featured agent description",
"runs": 100,
"rating": 4.5,
"agent_graph_id": "test-graph-1"
"rating": 4.5
}
],
"pagination": {

View File

@@ -31,10 +31,6 @@
"has_sensitive_action": false,
"trigger_setup_info": null,
"new_output": false,
"execution_count": 0,
"success_rate": null,
"avg_correctness_score": null,
"recent_executions": [],
"can_access_graph": true,
"is_latest_version": true,
"is_favorite": false,
@@ -76,10 +72,6 @@
"has_sensitive_action": false,
"trigger_setup_info": null,
"new_output": false,
"execution_count": 0,
"success_rate": null,
"avg_correctness_score": null,
"recent_executions": [],
"can_access_graph": false,
"is_latest_version": true,
"is_favorite": false,

View File

@@ -57,8 +57,7 @@ class TestDecomposeGoal:
result = await core.decompose_goal("Build a chatbot")
# library_agents defaults to None
mock_external.assert_called_once_with("Build a chatbot", "", None)
mock_external.assert_called_once_with("Build a chatbot", "")
assert result == expected_result
@pytest.mark.asyncio
@@ -75,8 +74,7 @@ class TestDecomposeGoal:
await core.decompose_goal("Build a chatbot", "Use Python")
# library_agents defaults to None
mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
@pytest.mark.asyncio
async def test_returns_none_on_service_failure(self):
@@ -111,7 +109,8 @@ class TestGenerateAgent:
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions, None, None, None)
mock_external.assert_called_once_with(instructions)
# Result should have id, version, is_active added if not present
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
@@ -175,9 +174,7 @@ class TestGenerateAgentPatch:
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with(
"Add a node", current_agent, None, None, None
)
mock_external.assert_called_once_with("Add a node", current_agent)
assert result == expected_result
@pytest.mark.asyncio

View File

@@ -1,857 +0,0 @@
"""
Tests for library agent fetching functionality in agent generator.
This test suite verifies the search-based library agent fetching,
including the combination of library and marketplace agents.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.agent_generator import core
class TestGetLibraryAgentsForGeneration:
"""Test get_library_agents_for_generation function."""
@pytest.mark.asyncio
async def test_fetches_agents_with_search_term(self):
"""Test that search_term is passed to the library db."""
# Create a mock agent with proper attribute values
mock_agent = MagicMock()
mock_agent.graph_id = "agent-123"
mock_agent.graph_version = 1
mock_agent.name = "Email Agent"
mock_agent.description = "Sends emails"
mock_agent.input_schema = {"properties": {}}
mock_agent.output_schema = {"properties": {}}
mock_agent.recent_executions = []
mock_response = MagicMock()
mock_response.agents = [mock_agent]
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_list:
result = await core.get_library_agents_for_generation(
user_id="user-123",
search_query="send email",
)
mock_list.assert_called_once_with(
user_id="user-123",
search_term="send email",
page=1,
page_size=15,
include_executions=True,
)
# Verify result format
assert len(result) == 1
assert result[0]["graph_id"] == "agent-123"
assert result[0]["name"] == "Email Agent"
@pytest.mark.asyncio
async def test_excludes_specified_graph_id(self):
"""Test that agents with excluded graph_id are filtered out."""
mock_response = MagicMock()
mock_response.agents = [
MagicMock(
graph_id="agent-123",
graph_version=1,
name="Agent 1",
description="First agent",
input_schema={},
output_schema={},
recent_executions=[],
),
MagicMock(
graph_id="agent-456",
graph_version=1,
name="Agent 2",
description="Second agent",
input_schema={},
output_schema={},
recent_executions=[],
),
]
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
):
result = await core.get_library_agents_for_generation(
user_id="user-123",
exclude_graph_id="agent-123",
)
# Verify the excluded agent is not in results
assert len(result) == 1
assert result[0]["graph_id"] == "agent-456"
@pytest.mark.asyncio
async def test_respects_max_results(self):
"""Test that max_results parameter limits the page_size."""
mock_response = MagicMock()
mock_response.agents = []
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_list:
await core.get_library_agents_for_generation(
user_id="user-123",
max_results=5,
)
mock_list.assert_called_once_with(
user_id="user-123",
search_term=None,
page=1,
page_size=5,
include_executions=True,
)
class TestSearchMarketplaceAgentsForGeneration:
"""Test search_marketplace_agents_for_generation function."""
@pytest.mark.asyncio
async def test_searches_marketplace_with_query(self):
"""Test that marketplace is searched with the query."""
mock_response = MagicMock()
mock_response.agents = [
MagicMock(
agent_name="Public Agent",
description="A public agent",
sub_heading="Does something useful",
creator="creator-1",
agent_graph_id="graph-123",
)
]
mock_graph = MagicMock()
mock_graph.id = "graph-123"
mock_graph.version = 1
mock_graph.input_schema = {"type": "object"}
mock_graph.output_schema = {"type": "object"}
with (
patch(
"backend.api.features.store.db.get_store_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_search,
patch(
"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
new_callable=AsyncMock,
return_value={"graph-123": mock_graph},
),
):
result = await core.search_marketplace_agents_for_generation(
search_query="automation",
max_results=10,
)
mock_search.assert_called_once_with(
search_query="automation",
page=1,
page_size=10,
)
assert len(result) == 1
assert result[0]["name"] == "Public Agent"
assert result[0]["graph_id"] == "graph-123"
@pytest.mark.asyncio
async def test_handles_marketplace_error_gracefully(self):
"""Test that marketplace errors don't crash the function."""
with patch(
"backend.api.features.store.db.get_store_agents",
new_callable=AsyncMock,
side_effect=Exception("Marketplace unavailable"),
):
result = await core.search_marketplace_agents_for_generation(
search_query="test"
)
# Should return empty list, not raise exception
assert result == []
class TestGetAllRelevantAgentsForGeneration:
"""Test get_all_relevant_agents_for_generation function."""
@pytest.mark.asyncio
async def test_combines_library_and_marketplace_agents(self):
"""Test that agents from both sources are combined."""
library_agents = [
{
"graph_id": "lib-123",
"graph_version": 1,
"name": "Library Agent",
"description": "From library",
"input_schema": {},
"output_schema": {},
}
]
marketplace_agents = [
{
"graph_id": "market-456",
"graph_version": 1,
"name": "Market Agent",
"description": "From marketplace",
"input_schema": {},
"output_schema": {},
}
]
with patch.object(
core,
"get_library_agents_for_generation",
new_callable=AsyncMock,
return_value=library_agents,
):
with patch.object(
core,
"search_marketplace_agents_for_generation",
new_callable=AsyncMock,
return_value=marketplace_agents,
):
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query="test query",
include_marketplace=True,
)
# Library agents should come first
assert len(result) == 2
assert result[0]["name"] == "Library Agent"
assert result[1]["name"] == "Market Agent"
@pytest.mark.asyncio
async def test_deduplicates_by_graph_id(self):
"""Test that marketplace agents with same graph_id as library are excluded."""
library_agents = [
{
"graph_id": "shared-123",
"graph_version": 1,
"name": "Shared Agent",
"description": "From library",
"input_schema": {},
"output_schema": {},
}
]
marketplace_agents = [
{
"graph_id": "shared-123", # Same graph_id, should be deduplicated
"graph_version": 1,
"name": "Shared Agent",
"description": "From marketplace",
"input_schema": {},
"output_schema": {},
},
{
"graph_id": "unique-456",
"graph_version": 1,
"name": "Unique Agent",
"description": "Only in marketplace",
"input_schema": {},
"output_schema": {},
},
]
with patch.object(
core,
"get_library_agents_for_generation",
new_callable=AsyncMock,
return_value=library_agents,
):
with patch.object(
core,
"search_marketplace_agents_for_generation",
new_callable=AsyncMock,
return_value=marketplace_agents,
):
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query="test",
include_marketplace=True,
)
# Shared Agent from marketplace should be excluded by graph_id
assert len(result) == 2
names = [a["name"] for a in result]
assert "Shared Agent" in names
assert "Unique Agent" in names
@pytest.mark.asyncio
async def test_skips_marketplace_when_disabled(self):
"""Test that marketplace is not searched when include_marketplace=False."""
library_agents = [
{
"graph_id": "lib-123",
"graph_version": 1,
"name": "Library Agent",
"description": "From library",
"input_schema": {},
"output_schema": {},
}
]
with patch.object(
core,
"get_library_agents_for_generation",
new_callable=AsyncMock,
return_value=library_agents,
):
with patch.object(
core,
"search_marketplace_agents_for_generation",
new_callable=AsyncMock,
) as mock_marketplace:
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query="test",
include_marketplace=False,
)
# Marketplace should not be called
mock_marketplace.assert_not_called()
assert len(result) == 1
@pytest.mark.asyncio
async def test_skips_marketplace_when_no_search_query(self):
"""Test that marketplace is not searched without a search query."""
library_agents = [
{
"graph_id": "lib-123",
"graph_version": 1,
"name": "Library Agent",
"description": "From library",
"input_schema": {},
"output_schema": {},
}
]
with patch.object(
core,
"get_library_agents_for_generation",
new_callable=AsyncMock,
return_value=library_agents,
):
with patch.object(
core,
"search_marketplace_agents_for_generation",
new_callable=AsyncMock,
) as mock_marketplace:
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query=None, # No search query
include_marketplace=True,
)
# Marketplace should not be called without search query
mock_marketplace.assert_not_called()
assert len(result) == 1
class TestExtractSearchTermsFromSteps:
"""Test extract_search_terms_from_steps function."""
def test_extracts_terms_from_instructions_type(self):
"""Test extraction from valid instructions decomposition result."""
decomposition_result = {
"type": "instructions",
"steps": [
{
"description": "Send an email notification",
"block_name": "GmailSendBlock",
},
{"description": "Fetch weather data", "action": "Get weather API"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert "Send an email notification" in result
assert "GmailSendBlock" in result
assert "Fetch weather data" in result
assert "Get weather API" in result
def test_returns_empty_for_non_instructions_type(self):
"""Test that non-instructions types return empty list."""
decomposition_result = {
"type": "clarifying_questions",
"questions": [{"question": "What email?"}],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert result == []
def test_deduplicates_terms_case_insensitively(self):
"""Test that duplicate terms are removed (case-insensitive)."""
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "Send Email", "name": "send email"},
{"description": "Other task"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
# Should only have one "send email" variant
email_terms = [t for t in result if "email" in t.lower()]
assert len(email_terms) == 1
def test_filters_short_terms(self):
"""Test that terms with 3 or fewer characters are filtered out."""
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "ab", "action": "xyz"}, # Both too short
{"description": "Valid term here"},
],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert "ab" not in result
assert "xyz" not in result
assert "Valid term here" in result
def test_handles_empty_steps(self):
"""Test handling of empty steps list."""
decomposition_result = {
"type": "instructions",
"steps": [],
}
result = core.extract_search_terms_from_steps(decomposition_result)
assert result == []
class TestEnrichLibraryAgentsFromSteps:
"""Test enrich_library_agents_from_steps function."""
@pytest.mark.asyncio
async def test_enriches_with_additional_agents(self):
"""Test that additional agents are found based on steps."""
existing_agents = [
{
"graph_id": "existing-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
additional_agents = [
{
"graph_id": "new-456",
"graph_version": 1,
"name": "Email Agent",
"description": "For sending emails",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "Send email notification"},
],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should have both existing and new agents
assert len(result) == 2
names = [a["name"] for a in result]
assert "Existing Agent" in names
assert "Email Agent" in names
@pytest.mark.asyncio
async def test_deduplicates_by_graph_id(self):
"""Test that agents with same graph_id are not duplicated."""
existing_agents = [
{
"graph_id": "agent-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
# Additional search returns same agent
additional_agents = [
{
"graph_id": "agent-123", # Same ID
"graph_version": 1,
"name": "Existing Agent Copy",
"description": "Same agent different name",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [{"description": "Some action"}],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should not duplicate
assert len(result) == 1
@pytest.mark.asyncio
async def test_deduplicates_by_name(self):
"""Test that agents with same name are not duplicated."""
existing_agents = [
{
"graph_id": "agent-123",
"graph_version": 1,
"name": "Email Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
# Additional search returns agent with same name but different ID
additional_agents = [
{
"graph_id": "agent-456", # Different ID
"graph_version": 1,
"name": "Email Agent", # Same name
"description": "Different agent same name",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "instructions",
"steps": [{"description": "Send email"}],
}
with patch.object(
core,
"get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=additional_agents,
):
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should not duplicate by name
assert len(result) == 1
assert result[0].get("graph_id") == "agent-123" # Original kept
@pytest.mark.asyncio
async def test_returns_existing_when_no_steps(self):
"""Test that existing agents are returned when no search terms extracted."""
existing_agents = [
{
"graph_id": "existing-123",
"graph_version": 1,
"name": "Existing Agent",
"description": "Already fetched",
"input_schema": {},
"output_schema": {},
}
]
decomposition_result = {
"type": "clarifying_questions", # Not instructions type
"questions": [],
}
result = await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should return existing unchanged
assert result == existing_agents
@pytest.mark.asyncio
async def test_limits_search_terms_to_three(self):
"""Test that only first 3 search terms are used."""
existing_agents = []
decomposition_result = {
"type": "instructions",
"steps": [
{"description": "First action"},
{"description": "Second action"},
{"description": "Third action"},
{"description": "Fourth action"},
{"description": "Fifth action"},
],
}
call_count = 0
async def mock_get_agents(*args, **kwargs):
nonlocal call_count
call_count += 1
return []
with patch.object(
core,
"get_all_relevant_agents_for_generation",
side_effect=mock_get_agents,
):
await core.enrich_library_agents_from_steps(
user_id="user-123",
decomposition_result=decomposition_result,
existing_agents=existing_agents,
)
# Should only make 3 calls (limited to first 3 terms)
assert call_count == 3
class TestExtractUuidsFromText:
"""Test extract_uuids_from_text function."""
def test_extracts_single_uuid(self):
"""Test extraction of a single UUID from text."""
text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
result = core.extract_uuids_from_text(text)
assert len(result) == 1
assert "46631191-e8a8-486f-ad90-84f89738321d" in result
def test_extracts_multiple_uuids(self):
"""Test extraction of multiple UUIDs from text."""
text = (
"Combine agents 11111111-1111-4111-8111-111111111111 "
"and 22222222-2222-4222-9222-222222222222"
)
result = core.extract_uuids_from_text(text)
assert len(result) == 2
assert "11111111-1111-4111-8111-111111111111" in result
assert "22222222-2222-4222-9222-222222222222" in result
def test_deduplicates_uuids(self):
"""Test that duplicate UUIDs are deduplicated."""
text = (
"Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
"46631191-e8a8-486f-ad90-84f89738321d"
)
result = core.extract_uuids_from_text(text)
assert len(result) == 1
def test_normalizes_to_lowercase(self):
"""Test that UUIDs are normalized to lowercase."""
text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
result = core.extract_uuids_from_text(text)
assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"
def test_returns_empty_for_no_uuids(self):
"""Test that empty list is returned when no UUIDs found."""
text = "Create an email agent that sends notifications"
result = core.extract_uuids_from_text(text)
assert result == []
def test_ignores_invalid_uuids(self):
"""Test that invalid UUID-like strings are ignored."""
text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
result = core.extract_uuids_from_text(text)
# UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
assert len(result) == 0
class TestGetLibraryAgentById:
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
@pytest.mark.asyncio
async def test_returns_agent_when_found_by_graph_id(self):
"""Test that agent is returned when found by graph_id."""
mock_agent = MagicMock()
mock_agent.graph_id = "agent-123"
mock_agent.graph_version = 1
mock_agent.name = "Test Agent"
mock_agent.description = "Test description"
mock_agent.input_schema = {"properties": {}}
mock_agent.output_schema = {"properties": {}}
with patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=mock_agent,
):
result = await core.get_library_agent_by_id("user-123", "agent-123")
assert result is not None
assert result["graph_id"] == "agent-123"
assert result["name"] == "Test Agent"
@pytest.mark.asyncio
async def test_falls_back_to_library_agent_id(self):
"""Test that lookup falls back to library agent ID when graph_id not found."""
mock_agent = MagicMock()
mock_agent.graph_id = "graph-456" # Different from the lookup ID
mock_agent.graph_version = 1
mock_agent.name = "Library Agent"
mock_agent.description = "Found by library ID"
mock_agent.input_schema = {"properties": {}}
mock_agent.output_schema = {"properties": {}}
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=None, # Not found by graph_id
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
return_value=mock_agent, # Found by library ID
),
):
result = await core.get_library_agent_by_id("user-123", "library-id-123")
assert result is not None
assert result["graph_id"] == "graph-456"
assert result["name"] == "Library Agent"
@pytest.mark.asyncio
async def test_returns_none_when_not_found_by_either_method(self):
"""Test that None is returned when agent not found by either method."""
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=None,
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
side_effect=core.NotFoundError("Not found"),
),
):
result = await core.get_library_agent_by_id("user-123", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_returns_none_on_exception(self):
"""Test that None is returned when exception occurs in both lookups."""
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
side_effect=Exception("Database error"),
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
side_effect=Exception("Database error"),
),
):
result = await core.get_library_agent_by_id("user-123", "agent-123")
assert result is None
@pytest.mark.asyncio
async def test_alias_works(self):
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
class TestGetAllRelevantAgentsWithUuids:
"""Test UUID extraction in get_all_relevant_agents_for_generation."""
@pytest.mark.asyncio
async def test_fetches_explicitly_mentioned_agents(self):
"""Test that agents mentioned by UUID are fetched directly."""
mock_agent = MagicMock()
mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
mock_agent.graph_version = 1
mock_agent.name = "Mentioned Agent"
mock_agent.description = "Explicitly mentioned"
mock_agent.input_schema = {}
mock_agent.output_schema = {}
mock_response = MagicMock()
mock_response.agents = []
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=mock_agent,
),
patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
),
):
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
include_marketplace=False,
)
assert len(result) == 1
assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
assert result[0].get("name") == "Mentioned Agent"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:
@pytest.mark.asyncio
async def test_decompose_goal_with_context(self):
"""Test decomposition with additional context enriched into description."""
"""Test decomposition with additional context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
"Build a chatbot", context="Use Python"
)
expected_description = (
"Build a chatbot\n\nAdditional context from user:\nUse Python"
)
mock_client.post.assert_called_once_with(
"/api/decompose-description",
json={"description": expected_description},
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
)
@pytest.mark.asyncio
@@ -436,139 +433,5 @@ class TestGetBlocksExternal:
assert result is None
class TestLibraryAgentsPassthrough:
"""Test that library_agents are passed correctly in all requests."""
def setup_method(self):
"""Reset client singleton before each test."""
service._settings = None
service._client = None
@pytest.mark.asyncio
async def test_decompose_goal_passes_library_agents(self):
"""Test that library_agents are included in decompose goal payload."""
library_agents = [
{
"graph_id": "agent-123",
"graph_version": 1,
"name": "Email Sender",
"description": "Sends emails",
"input_schema": {"properties": {"to": {"type": "string"}}},
"output_schema": {"properties": {"sent": {"type": "boolean"}}},
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.decompose_goal_external(
"Send an email",
library_agents=library_agents,
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_generate_agent_passes_library_agents(self):
"""Test that library_agents are included in generate agent payload."""
library_agents = [
{
"graph_id": "agent-456",
"graph_version": 2,
"name": "Data Fetcher",
"description": "Fetches data from API",
"input_schema": {"properties": {"url": {"type": "string"}}},
"output_schema": {"properties": {"data": {"type": "object"}}},
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": {"name": "Test Agent", "nodes": []},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.generate_agent_external(
{"steps": ["Step 1"]},
library_agents=library_agents,
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_generate_agent_patch_passes_library_agents(self):
"""Test that library_agents are included in patch generation payload."""
library_agents = [
{
"graph_id": "agent-789",
"graph_version": 1,
"name": "Slack Notifier",
"description": "Sends Slack messages",
"input_schema": {"properties": {"message": {"type": "string"}}},
"output_schema": {"properties": {"success": {"type": "boolean"}}},
},
]
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"agent_json": {"name": "Updated Agent", "nodes": []},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.generate_agent_patch_external(
"Add error handling",
{"name": "Original Agent", "nodes": []},
library_agents=library_agents,
)
# Verify library_agents was passed in the payload
call_args = mock_client.post.call_args
assert call_args[1]["json"]["library_agents"] == library_agents
@pytest.mark.asyncio
async def test_decompose_goal_without_library_agents(self):
"""Test that decompose goal works without library_agents."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
"type": "instructions",
"steps": ["Step 1"],
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
with patch.object(service, "_get_client", return_value=mock_client):
await service.decompose_goal_external("Build a workflow")
# Verify library_agents was NOT passed when not provided
call_args = mock_client.post.call_args
assert "library_agents" not in call_args[1]["json"]
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -43,24 +43,19 @@ faker = Faker()
# Constants for data generation limits (reduced for E2E tests)
NUM_USERS = 15
NUM_AGENT_BLOCKS = 30
MIN_GRAPHS_PER_USER = 25
MAX_GRAPHS_PER_USER = 25
MIN_GRAPHS_PER_USER = 15
MAX_GRAPHS_PER_USER = 15
MIN_NODES_PER_GRAPH = 3
MAX_NODES_PER_GRAPH = 6
MIN_PRESETS_PER_USER = 2
MAX_PRESETS_PER_USER = 3
MIN_AGENTS_PER_USER = 25
MAX_AGENTS_PER_USER = 25
MIN_AGENTS_PER_USER = 15
MAX_AGENTS_PER_USER = 15
MIN_EXECUTIONS_PER_GRAPH = 2
MAX_EXECUTIONS_PER_GRAPH = 8
MIN_REVIEWS_PER_VERSION = 2
MAX_REVIEWS_PER_VERSION = 5
# Guaranteed minimums for marketplace tests (deterministic)
GUARANTEED_FEATURED_AGENTS = 8
GUARANTEED_FEATURED_CREATORS = 5
GUARANTEED_TOP_AGENTS = 10
def get_image():
"""Generate a consistent image URL using picsum.photos service."""
@@ -390,7 +385,7 @@ class TestDataCreator:
library_agents = []
for user in self.users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
num_agents = 10 # Create exactly 10 agents per user
# Get available graphs for this user
user_graphs = [
@@ -512,17 +507,14 @@ class TestDataCreator:
existing_profiles, min(num_creators, len(existing_profiles))
)
# Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
# Mark about 50% of creators as featured (more for testing)
num_featured = max(2, int(num_creators * 0.5))
num_featured = min(
num_featured, len(selected_profiles)
) # Don't exceed available profiles
featured_profile_ids = set(
random.sample([p.id for p in selected_profiles], num_featured)
)
print(
f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
)
for profile in selected_profiles:
try:
@@ -553,25 +545,21 @@ class TestDataCreator:
return profiles
async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
"""Create test store submissions using the API function.
DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
"""
"""Create test store submissions using the API function."""
print("Creating test store submissions...")
submissions = []
approved_submissions = []
featured_count = 0
submission_counter = 0
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
# Create a special test submission for test123@gmail.com
test_user = next(
(user for user in self.users if user["email"] == "test123@gmail.com"), None
)
if test_user and self.agent_graphs:
if test_user:
# Special test data for consistent testing
test_submission_data = {
"user_id": test_user["id"],
"agent_id": self.agent_graphs[0]["id"],
"agent_id": self.agent_graphs[0]["id"], # Use first available graph
"agent_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
@@ -592,24 +580,37 @@ class TestDataCreator:
submissions.append(test_submission.model_dump())
print("✅ Created special test store submission for test123@gmail.com")
# ALWAYS approve and feature the test submission
# Randomly approve, reject, or leave pending the test submission
if test_submission.store_listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
reviewer_id=test_user["id"],
)
approved_submissions.append(approved_submission.model_dump())
print("✅ Approved test store submission")
random_value = random.random()
if random_value < 0.4: # 40% chance to approve
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
reviewer_id=test_user["id"],
)
approved_submissions.append(approved_submission.model_dump())
print("✅ Approved test store submission")
await prisma.storelistingversion.update(
where={"id": test_submission.store_listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
print("🌟 Marked test agent as FEATURED")
# Mark approved submission as featured
await prisma.storelistingversion.update(
where={"id": test_submission.store_listing_version_id},
data={"isFeatured": True},
)
print("🌟 Marked test agent as FEATURED")
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
is_approved=False,
external_comments="Test submission rejected - needs improvements",
internal_comments="Auto-rejected test submission for E2E testing",
reviewer_id=test_user["id"],
)
print("❌ Rejected test store submission")
else: # 30% chance to leave pending (70% to 100%)
print("⏳ Left test submission pending for review")
except Exception as e:
print(f"Error creating test store submission: {e}")
@@ -619,6 +620,7 @@ class TestDataCreator:
# Create regular submissions for all users
for user in self.users:
# Get available graphs for this specific user
user_graphs = [
g for g in self.agent_graphs if g.get("userId") == user["id"]
]
@@ -629,17 +631,18 @@ class TestDataCreator:
)
continue
# Create exactly 4 store submissions per user
for submission_index in range(4):
graph = random.choice(user_graphs)
submission_counter += 1
try:
print(
f"Creating store submission for user {user['id']} with graph {graph['id']}"
f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
)
# Use the API function to create store submission with correct parameters
submission = await create_store_submission(
user_id=user["id"],
user_id=user["id"], # Must match graph's userId
agent_id=graph["id"],
agent_version=graph.get("version", 1),
slug=faker.slug(),
@@ -648,24 +651,22 @@ class TestDataCreator:
video_url=get_video_url() if random.random() < 0.3 else None,
image_urls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[get_category()],
categories=[
get_category()
], # Single category from predefined list
changes_summary="Initial E2E test submission",
)
submissions.append(submission.model_dump())
print(f"✅ Created store submission: {submission.name}")
# Randomly approve, reject, or leave pending the submission
if submission.store_listing_version_id:
# DETERMINISTIC: First N submissions are always approved
# First GUARANTEED_FEATURED_AGENTS of those are always featured
should_approve = (
submission_counter <= GUARANTEED_TOP_AGENTS
or random.random() < 0.4
)
should_feature = featured_count < GUARANTEED_FEATURED_AGENTS
if should_approve:
random_value = random.random()
if random_value < 0.4: # 40% chance to approve
try:
# Pick a random user as the reviewer (admin)
reviewer_id = random.choice(self.users)["id"]
approved_submission = await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
is_approved=True,
@@ -680,7 +681,16 @@ class TestDataCreator:
f"✅ Approved store submission: {submission.name}"
)
if should_feature:
# Mark some agents as featured during creation (30% chance)
# More likely for creators and first submissions
is_creator = user["id"] in [
p.get("userId") for p in self.profiles
]
feature_chance = (
0.5 if is_creator else 0.2
) # 50% for creators, 20% for others
if random.random() < feature_chance:
try:
await prisma.storelistingversion.update(
where={
@@ -688,25 +698,8 @@ class TestDataCreator:
},
data={"isFeatured": True},
)
featured_count += 1
print(
f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
)
except Exception as e:
print(
f"Warning: Could not mark submission as featured: {e}"
)
elif random.random() < 0.2:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
data={"isFeatured": True},
)
featured_count += 1
print(
f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
f"🌟 Marked agent as FEATURED: {submission.name}"
)
except Exception as e:
print(
@@ -717,9 +710,11 @@ class TestDataCreator:
print(
f"Warning: Could not approve submission {submission.name}: {e}"
)
elif random.random() < 0.5:
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
try:
# Pick a random user as the reviewer (admin)
reviewer_id = random.choice(self.users)["id"]
await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
is_approved=False,
@@ -734,7 +729,7 @@ class TestDataCreator:
print(
f"Warning: Could not reject submission {submission.name}: {e}"
)
else:
else: # 30% chance to leave pending (70% to 100%)
print(
f"⏳ Left submission pending for review: {submission.name}"
)
@@ -748,13 +743,9 @@ class TestDataCreator:
traceback.print_exc()
continue
print("\n📊 Store Submissions Summary:")
print(f" Created: {len(submissions)}")
print(f" Approved: {len(approved_submissions)}")
print(
f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
)
self.store_submissions = submissions
return submissions
@@ -834,15 +825,12 @@ class TestDataCreator:
print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
print(f"✅ Library agents created: {len(self.library_agents)}")
print(f"✅ Creator profiles updated: {len(self.profiles)}")
print(f"✅ Store submissions created: {len(self.store_submissions)}")
print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
print(
f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
)
print(f"✅ API keys created: {len(self.api_keys)}")
print(f"✅ Presets created: {len(self.presets)}")
print("\n🎯 Deterministic Guarantees:")
print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}")
print("\n🚀 Your E2E test database is ready to use!")

View File

@@ -1,9 +1,10 @@
"use client";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getHomepageRoute } from "@/lib/constants";
export default function OnboardingPage() {
const router = useRouter();
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
async function redirectToStep() {
try {
// Check if onboarding is enabled (also gets chat flag for redirect)
const { shouldShowOnboarding } = await getOnboardingStatus();
const { shouldShowOnboarding, isChatEnabled } =
await getOnboardingStatus();
const homepageRoute = getHomepageRoute(isChatEnabled);
if (!shouldShowOnboarding) {
router.replace("/");
router.replace(homepageRoute);
return;
}
@@ -23,7 +26,7 @@ export default function OnboardingPage() {
// Handle completed onboarding
if (onboarding.completedSteps.includes("GET_RESULTS")) {
router.replace("/");
router.replace(homepageRoute);
return;
}

View File

@@ -1,8 +1,9 @@
import { getOnboardingStatus } from "@/app/api/helpers";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { revalidatePath } from "next/cache";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { NextResponse } from "next/server";
import { revalidatePath } from "next/cache";
import { getOnboardingStatus } from "@/app/api/helpers";
// Handle the callback to complete the user session login
export async function GET(request: Request) {
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
await api.createUser();
// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const { shouldShowOnboarding, isChatEnabled } =
await getOnboardingStatus();
if (shouldShowOnboarding) {
next = "/onboarding";
revalidatePath("/onboarding", "layout");
} else {
next = "/";
next = getHomepageRoute(isChatEnabled);
revalidatePath(next, "layout");
}
} catch (createUserError) {

View File

@@ -1,17 +1,6 @@
import { OAuthPopupResultMessage } from "./types";
import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// This route is intended to be used as the callback for integration OAuth flows,
// controlled by the CredentialsInput component. The CredentialsInput opens the login
// page in a pop-up window, which then redirects to this route to close the loop.
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message);
// Return a response with the message as JSON and a script to close the window
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
return new NextResponse(
`
<html>
<body>
<script>
window.opener.postMessage(${safeJsonStringify(message)});
window.opener.postMessage(${JSON.stringify(message)});
window.close();
</script>
</body>

View File

@@ -857,7 +857,7 @@ export const CustomNode = React.memo(
})();
const hasAdvancedFields =
data.inputSchema?.properties &&
data.inputSchema &&
Object.entries(data.inputSchema.properties).some(([key, value]) => {
return (
value.advanced === true && !data.inputSchema.required?.includes(key)

View File

@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -69,16 +70,41 @@ export function useCopilotShell() {
});
const stopStream = useChatStore((s) => s.stopStream);
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const isStreaming = useCopilotStore((s) => s.isStreaming);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
const pendingActionRef = useRef<(() => void) | null>(null);
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
async function stopCurrentStream() {
if (!currentSessionId) return;
setIsSwitchingSession(true);
await new Promise<void>((resolve) => {
const unsubscribe = onStreamComplete((completedId) => {
if (completedId === currentSessionId) {
clearTimeout(timeout);
unsubscribe();
resolve();
}
});
const timeout = setTimeout(() => {
unsubscribe();
resolve();
}, 3000);
stopStream(currentSessionId);
}
});
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
});
setIsSwitchingSession(false);
}
function selectSession(sessionId: string) {
if (sessionId === currentSessionId) return;
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -88,12 +114,7 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function handleNewChatClick() {
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
function startNewChat() {
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
@@ -102,6 +123,32 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
selectSession(sessionId);
};
openInterruptModal(pendingActionRef.current);
} else {
selectSession(sessionId);
}
}
function handleNewChatClick() {
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
startNewChat();
};
openInterruptModal(pendingActionRef.current);
} else {
startNewChat();
}
}
return {
isMobile,
isDrawerOpen,

View File

@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {
export function getQuickActions(): string[] {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
"Show me what I can automate",
"Design a custom workflow",
"Help me with content creation",
];
}
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

View File

@@ -1,13 +1,6 @@
"use client";
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { type ReactNode } from "react";
import type { ReactNode } from "react";
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
export default function CopilotLayout({ children }: { children: ReactNode }) {
return (
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
<CopilotShell>{children}</CopilotShell>
</FeatureFlagPage>
);
return <CopilotShell>{children}</CopilotShell>;
}

View File

@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useEffect, useState } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getInputPlaceholder } from "./helpers";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
@@ -16,25 +14,14 @@ export default function CopilotPage() {
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
const handleResize = () => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
};
handleResize();
window.addEventListener("resize", handleResize);
return () => window.removeEventListener("resize", handleResize);
}, []);
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
state;
const {
greetingName,
quickActions,
isLoading,
hasSession,
initialPrompt,
isReady,
} = state;
const {
handleQuickAction,
startChatWithPrompt,
@@ -42,6 +29,8 @@ export default function CopilotPage() {
handleStreamingChange,
} = handlers;
if (!isReady) return null;
if (hasSession) {
return (
<div className="flex h-full flex-col">
@@ -92,7 +81,7 @@ export default function CopilotPage() {
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
<div className="w-full text-center">
{isLoading ? (
<div className="mx-auto max-w-2xl">
@@ -109,25 +98,25 @@ export default function CopilotPage() {
</div>
) : (
<>
<div className="mx-auto max-w-3xl">
<div className="mx-auto max-w-2xl">
<Text
variant="h3"
className="mb-1 !text-[1.375rem] text-zinc-700"
className="mb-3 !text-[1.375rem] text-zinc-700"
>
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
What do you want to automate?
</Text>
<div className="mb-6">
<ChatInput
onSend={startChatWithPrompt}
placeholder={inputPlaceholder}
placeholder='You can search or just ask - e.g. "create a blog post outline"'
/>
</div>
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => (
<Button
key={action}
@@ -135,7 +124,7 @@ export default function CopilotPage() {
variant="outline"
size="small"
onClick={() => handleQuickAction(action)}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
>
{action}
</Button>

View File

@@ -3,11 +3,18 @@ import {
postV2CreateSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";
@@ -26,6 +33,22 @@ export function useCopilotPage() {
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
@@ -35,8 +58,11 @@ export function useCopilotPage() {
: undefined;
useEffect(() => {
if (isLoggedIn) completeStep("VISIT_COPILOT");
}, [completeStep, isLoggedIn]);
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
async function startChatWithPrompt(prompt: string) {
if (!prompt?.trim()) return;
@@ -90,6 +116,7 @@ export function useCopilotPage() {
isLoading: isUserLoading,
hasSession,
initialPrompt,
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
},
handlers: {
handleQuickAction,

View File

@@ -1,6 +1,8 @@
"use client";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSearchParams } from "next/navigation";
import { Suspense } from "react";
import { getErrorDetails } from "./helpers";
@@ -9,6 +11,8 @@ function ErrorPageContent() {
const searchParams = useSearchParams();
const errorMessage = searchParams.get("message");
const errorDetails = getErrorDetails(errorMessage);
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
function handleRetry() {
// Auth-related errors should redirect to login
@@ -26,7 +30,7 @@ function ErrorPageContent() {
}, 2000);
} else {
// For server/network errors, go to home
window.location.href = "/";
window.location.href = homepageRoute;
}
}

View File

@@ -1,5 +1,6 @@
"use server";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { loginFormSchema } from "@/types/auth";
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
await api.createUser();
// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const next = shouldShowOnboarding ? "/onboarding" : "/";
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);
return {
success: true,

View File

@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useLoginPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");
useEffect(() => {
if (isLoggedIn && !isLoggingIn) {
router.push(nextUrl || "/");
router.push(nextUrl || homepageRoute);
}
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
@@ -94,7 +98,7 @@ export function useLoginPage() {
}
// Prefer URL's next parameter, then use backend-determined route
router.replace(nextUrl || result.next || "/");
router.replace(nextUrl || result.next || homepageRoute);
} catch (error) {
toast({
title:

View File

@@ -1,5 +1,6 @@
"use server";
import { getHomepageRoute } from "@/lib/constants";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
@@ -58,8 +59,10 @@ export async function signup(
}
// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const next = shouldShowOnboarding ? "/onboarding" : "/";
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);
return { success: true, next };
} catch (err) {

View File

@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { LoginProvider, signupFormSchema } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useSignupPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");
useEffect(() => {
if (isLoggedIn && !isSigningUp) {
router.push(nextUrl || "/");
router.push(nextUrl || homepageRoute);
}
}, [isLoggedIn, isSigningUp, nextUrl, router]);
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
const form = useForm<z.infer<typeof signupFormSchema>>({
resolver: zodResolver(signupFormSchema),
@@ -129,7 +133,7 @@ export function useSignupPage() {
}
// Prefer the URL's next parameter, then result.next (for onboarding), then default
const redirectTo = nextUrl || result.next || "/";
const redirectTo = nextUrl || result.next || homepageRoute;
router.replace(redirectTo);
} catch (error) {
setIsLoading(false);

View File

@@ -1,81 +0,0 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
/**
* SSE Proxy for task stream reconnection.
*
* This endpoint allows clients to reconnect to an ongoing or recently completed
* background task's stream. It replays missed messages from Redis Streams and
* subscribes to live updates if the task is still running.
*
* Client contract:
* 1. When receiving an operation_started event, store the task_id
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
* 3. Messages are replayed from the last_message_id position
* 4. Stream ends when "finish" event is received
*/
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
// Get auth token from server-side session
const token = await getServerAuthToken();
// Build backend URL
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
// Forward request to backend with auth header
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
// Return the SSE stream directly
return new Response(response.body, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache, no-transform",
Connection: "keep-alive",
"X-Accel-Buffering": "no",
},
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}

View File

@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
return {
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
isChatEnabled: status.is_chat_enabled,
};
}

View File

@@ -917,28 +917,6 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/config/ttl": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Ttl Config",
"description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
"operationId": "getV2GetTtlConfig",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"additionalProperties": true,
"type": "object",
"title": "Response Getv2Getttlconfig"
}
}
}
}
}
}
},
"/api/chat/health": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -961,63 +939,6 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -1101,7 +1022,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1236,7 +1157,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1274,94 +1195,6 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -6335,18 +6168,6 @@
"title": "AccuracyTrendsResponse",
"description": "Response model for accuracy trends and alerts."
},
"ActiveStreamInfo": {
"properties": {
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" },
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
},
"type": "object",
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
"AddUserCreditsResponse": {
"properties": {
"new_balance": { "type": "integer", "title": "New Balance" },
@@ -8160,25 +7981,6 @@
]
},
"new_output": { "type": "boolean", "title": "New Output" },
"execution_count": {
"type": "integer",
"title": "Execution Count",
"default": 0
},
"success_rate": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Success Rate"
},
"avg_correctness_score": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Avg Correctness Score"
},
"recent_executions": {
"items": { "$ref": "#/components/schemas/RecentExecution" },
"type": "array",
"title": "Recent Executions",
"description": "List of recent executions with status, score, and summary"
},
"can_access_graph": {
"type": "boolean",
"title": "Can Access Graph"
@@ -9002,27 +8804,6 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"Pagination": {
"properties": {
"total_items": {
@@ -9593,23 +9374,6 @@
"required": ["providers", "pagination"],
"title": "ProviderResponse"
},
"RecentExecution": {
"properties": {
"status": { "type": "string", "title": "Status" },
"correctness_score": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Correctness Score"
},
"activity_summary": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Activity Summary"
}
},
"type": "object",
"required": ["status"],
"title": "RecentExecution",
"description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics."
},
"RefundRequest": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -9878,12 +9642,6 @@
"items": { "additionalProperties": true, "type": "object" },
"type": "array",
"title": "Messages"
},
"active_stream": {
"anyOf": [
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
}
},
"type": "object",
@@ -10039,8 +9797,7 @@
"sub_heading": { "type": "string", "title": "Sub Heading" },
"description": { "type": "string", "title": "Description" },
"runs": { "type": "integer", "title": "Runs" },
"rating": { "type": "number", "title": "Rating" },
"agent_graph_id": { "type": "string", "title": "Agent Graph Id" }
"rating": { "type": "number", "title": "Rating" }
},
"type": "object",
"required": [
@@ -10052,8 +9809,7 @@
"sub_heading",
"description",
"runs",
"rating",
"agent_graph_id"
"rating"
],
"title": "StoreAgent"
},

View File

@@ -1,15 +1,27 @@
"use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
export default function Page() {
const isChatEnabled = useGetFlag(Flag.CHAT);
const router = useRouter();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
useEffect(() => {
router.replace("/copilot");
}, [router]);
useEffect(
function redirectToHomepage() {
if (!isFlagReady) return;
router.replace(homepageRoute);
},
[homepageRoute, isFlagReady, router],
);
return <LoadingSpinner size="large" cover />;
return null;
}

View File

@@ -1,6 +1,7 @@
"use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
@@ -24,8 +25,8 @@ export function Chat({
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
session,
messages,
isLoading,
isCreating,
@@ -37,18 +38,6 @@ export function Chat({
startPollingForOperation,
} = useChat({ urlSessionId });
// Extract active stream info for reconnection
const activeStream = (
session as {
active_stream?: {
task_id: string;
last_message_id: string;
operation_id: string;
tool_name: string;
};
}
)?.active_stream;
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
@@ -64,7 +53,8 @@ export function Chat({
isCreating,
]);
const shouldShowLoader = showLoader && (isLoading || isCreating);
const shouldShowLoader =
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
return (
<div className={cn("flex h-full flex-col", className)}>
@@ -76,19 +66,21 @@ export function Chat({
<div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500">
Loading your chat...
{isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
</Text>
</div>
</div>
)}
{/* Error State */}
{error && !isLoading && (
{error && !isLoading && !isSwitchingSession && (
<ChatErrorState error={error} onRetry={createSession} />
)}
{/* Session Content */}
{sessionId && !isLoading && !error && (
{sessionId && !isLoading && !error && !isSwitchingSession && (
<ChatContainer
sessionId={sessionId}
initialMessages={messages}
@@ -96,16 +88,6 @@ export function Chat({
className="flex-1"
onStreamingChange={onStreamingChange}
onOperationStarted={startPollingForOperation}
activeStream={
activeStream
? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
operationId: activeStream.operation_id,
toolName: activeStream.tool_name,
}
: undefined
}
/>
)}
</main>

View File

@@ -1,159 +0,0 @@
# SSE Reconnection Contract for Long-Running Operations
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
## Overview
When a user triggers a long-running operation (like agent generation), the backend:
1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay
Clients can reconnect to the task stream at any time to receive missed messages.
## Client-Side Flow
### 1. Receiving Operation Started
When you receive an `operation_started` tool response:
```typescript
// The response includes a task_id for reconnection
{
type: "operation_started",
tool_name: "generate_agent",
operation_id: "uuid-...",
task_id: "task-uuid-...", // <-- Store this for reconnection
message: "Operation started. You can close this tab."
}
```
### 2. Storing Task Info
Use the chat store to track the active task:
```typescript
import { useChatStore } from "./chat-store";
// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
taskId: response.task_id,
operationId: response.operation_id,
toolName: response.tool_name,
lastMessageId: "0",
});
```
### 3. Reconnecting to a Task
To reconnect (e.g., after page refresh or tab reopen):
```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();
// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);
if (activeTask) {
// Reconnect to the task stream
await reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId, // Resume from last position
(chunk) => {
// Handle incoming chunks
console.log("Received chunk:", chunk);
},
);
}
```
### 4. Tracking Message Position
To enable precise replay, update the last message ID as chunks arrive:
```typescript
const { updateTaskLastMessageId } = useChatStore.getState();
function handleChunk(chunk: StreamChunk) {
// If chunk has an index/id, track it
if (chunk.idx !== undefined) {
updateTaskLastMessageId(sessionId, String(chunk.idx));
}
}
```
## API Endpoints
### Task Stream Reconnection
```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```
- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)
Returns: SSE stream of missed messages + live updates
## Chunk Types
The reconnected stream follows the same Vercel AI SDK protocol:
| Type | Description |
| ----------------------- | ----------------------- |
| `start` | Message lifecycle start |
| `text-delta` | Streaming text content |
| `text-end` | Text block completed |
| `tool-output-available` | Tool result available |
| `finish` | Stream completed |
| `error` | Error occurred |
## Error Handling
If reconnection fails:
1. Check if task still exists (may have expired - default TTL: 1 hour)
2. Fall back to polling the session for final state
3. Show appropriate UI message to user
## Persistence Considerations
For robust reconnection across browser restarts:
```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
tasks[sessionId] = task;
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```
## Backend Configuration
The following backend settings affect reconnection behavior:
| Setting | Default | Description |
| ------------------- | ------- | ---------------------------------- |
| `stream_ttl` | 3600s | How long streams are kept in Redis |
| `stream_max_length` | 1000 | Max messages per stream |
## Testing
To test reconnection locally:
1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed
See the main README for full local development setup.

View File

@@ -1,16 +0,0 @@
/**
* Constants for the chat system.
*
* Centralizes magic strings and values used across chat components.
*/
// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";
// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour

View File

@@ -1,12 +1,6 @@
"use client";
import { create } from "zustand";
import {
ACTIVE_TASK_TTL_MS,
COMPLETED_STREAM_TTL_MS,
INITIAL_STREAM_ID,
STORAGE_KEY_ACTIVE_TASKS,
} from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -14,59 +8,15 @@ import type {
StreamResult,
StreamStatus,
} from "./chat-types";
import { executeStream, executeTaskReconnect } from "./stream-executor";
import { executeStream } from "./stream-executor";
export interface ActiveTaskInfo {
taskId: string;
sessionId: string;
operationId: string;
toolName: string;
lastMessageId: string;
startedAt: number;
}
/** Load active tasks from localStorage */
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
if (typeof window === "undefined") return new Map();
try {
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
if (!stored) return new Map();
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
const now = Date.now();
const tasks = new Map<string, ActiveTaskInfo>();
// Filter out expired tasks
for (const [sessionId, task] of Object.entries(parsed)) {
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
tasks.set(sessionId, task);
}
}
return tasks;
} catch {
return new Map();
}
}
/** Save active tasks to localStorage */
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
if (typeof window === "undefined") return;
try {
const obj: Record<string, ActiveTaskInfo> = {};
for (const [sessionId, task] of tasks) {
obj[sessionId] = task;
}
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
} catch {
// Ignore storage errors
}
}
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
interface ChatStoreState {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
activeSessions: Set<string>;
streamCompleteCallbacks: Set<StreamCompleteCallback>;
/** Active tasks for SSE reconnection - keyed by sessionId */
activeTasks: Map<string, ActiveTaskInfo>;
}
interface ChatStoreActions {
@@ -91,24 +41,6 @@ interface ChatStoreActions {
unregisterActiveSession: (sessionId: string) => void;
isSessionActive: (sessionId: string) => boolean;
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
/** Track active task for SSE reconnection */
setActiveTask: (
sessionId: string,
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
) => void;
/** Get active task for a session */
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
/** Clear active task when operation completes */
clearActiveTask: (sessionId: string) => void;
/** Reconnect to an existing task stream */
reconnectToTask: (
sessionId: string,
taskId: string,
lastMessageId?: string,
onChunk?: (chunk: StreamChunk) => void,
) => Promise<void>;
/** Update last message ID for a task (for tracking replay position) */
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
}
type ChatStore = ChatStoreState & ChatStoreActions;
@@ -132,126 +64,18 @@ function cleanupExpiredStreams(
const now = Date.now();
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
cleaned.delete(sessionId);
}
}
return cleaned;
}
/**
* Finalize a stream by moving it from activeStreams to completedStreams.
* Also handles cleanup and notifications.
*/
function finalizeStream(
sessionId: string,
stream: ActiveStream,
onChunk: ((chunk: StreamChunk) => void) | undefined,
get: () => ChatStoreState & ChatStoreActions,
set: (state: Partial<ChatStoreState>) => void,
): void {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
}
}
}
}
/**
* Clean up an existing stream for a session and move it to completed streams.
* Returns updated maps for both active and completed streams.
*/
function cleanupExistingStream(
sessionId: string,
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
callbacks: Set<StreamCompleteCallback>,
): {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
} {
const newActiveStreams = new Map(activeStreams);
let newCompletedStreams = new Map(completedStreams);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
return {
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
};
}
/**
* Create a new active stream with initial state.
*/
function createActiveStream(
sessionId: string,
onChunk?: (chunk: StreamChunk) => void,
): ActiveStream {
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
return {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
}
export const useChatStore = create<ChatStore>((set, get) => ({
activeStreams: new Map(),
completedStreams: new Map(),
activeSessions: new Set(),
streamCompleteCallbacks: new Set(),
activeTasks: loadPersistedTasks(),
startStream: async function startStream(
sessionId,
@@ -261,21 +85,45 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
) {
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const {
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
} = cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
const stream: ActiveStream = {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
// Create new stream
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
@@ -285,7 +133,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
try {
await executeStream(stream, message, isUserMessage, context);
} finally {
finalizeStream(sessionId, stream, onChunk, get, set);
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
}
}
}
}
},
@@ -409,93 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
setActiveTask: function setActiveTask(sessionId, taskInfo) {
const state = get();
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...taskInfo,
sessionId,
startedAt: Date.now(),
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
getActiveTask: function getActiveTask(sessionId) {
return get().activeTasks.get(sessionId);
},
clearActiveTask: function clearActiveTask(sessionId) {
const state = get();
if (!state.activeTasks.has(sessionId)) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
reconnectToTask: async function reconnectToTask(
sessionId,
taskId,
lastMessageId = INITIAL_STREAM_ID,
onChunk,
) {
const state = get();
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const {
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
} = cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream for reconnection
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
try {
await executeTaskReconnect(stream, taskId, lastMessageId);
} finally {
finalizeStream(sessionId, stream, onChunk, get, set);
// Clear active task on completion
if (stream.status === "completed" || stream.status === "error") {
const taskState = get();
if (taskState.activeTasks.has(sessionId)) {
const newActiveTasks = new Map(taskState.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
}
}
}
},
updateTaskLastMessageId: function updateTaskLastMessageId(
sessionId,
lastMessageId,
) {
const state = get();
const task = state.activeTasks.get(sessionId);
if (!task) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...task,
lastMessageId,
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
}));

View File

@@ -4,7 +4,6 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
export interface StreamChunk {
type:
| "stream_start"
| "text_chunk"
| "text_ended"
| "tool_call"
@@ -16,7 +15,6 @@ export interface StreamChunk {
| "error"
| "usage"
| "stream_end";
taskId?: string;
timestamp?: string;
content?: string;
message?: string;
@@ -43,7 +41,7 @@ export interface StreamChunk {
}
export type VercelStreamChunk =
| { type: "start"; messageId: string; taskId?: string }
| { type: "start"; messageId: string }
| { type: "finish" }
| { type: "text-start"; id: string }
| { type: "text-delta"; id: string; delta: string }
@@ -94,70 +92,3 @@ export interface StreamResult {
}
export type StreamCompleteCallback = (sessionId: string) => void;
// Type guards for message types
/**
* Check if a message has a toolId property.
*/
export function hasToolId<T extends { type: string }>(
msg: T,
): msg is T & { toolId: string } {
return (
"toolId" in msg &&
typeof (msg as Record<string, unknown>).toolId === "string"
);
}
/**
* Check if a message has an operationId property.
*/
export function hasOperationId<T extends { type: string }>(
msg: T,
): msg is T & { operationId: string } {
return (
"operationId" in msg &&
typeof (msg as Record<string, unknown>).operationId === "string"
);
}
/**
* Check if a message has a toolCallId property.
*/
export function hasToolCallId<T extends { type: string }>(
msg: T,
): msg is T & { toolCallId: string } {
return (
"toolCallId" in msg &&
typeof (msg as Record<string, unknown>).toolCallId === "string"
);
}
/**
* Check if a message is an operation message type.
*/
export function isOperationMessage<T extends { type: string }>(
msg: T,
): msg is T & {
type: "operation_started" | "operation_pending" | "operation_in_progress";
} {
return (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
);
}
/**
* Get the tool ID from a message if available.
* Checks toolId, operationId, and toolCallId properties.
*/
export function getToolIdFromMessage<T extends { type: string }>(
msg: T,
): string | undefined {
const record = msg as Record<string, unknown>;
if (typeof record.toolId === "string") return record.toolId;
if (typeof record.operationId === "string") return record.operationId;
if (typeof record.toolCallId === "string") return record.toolCallId;
return undefined;
}

View File

@@ -2,6 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { cn } from "@/lib/utils";
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
import { useEffect } from "react";
@@ -16,13 +17,6 @@ export interface ChatContainerProps {
className?: string;
onStreamingChange?: (isStreaming: boolean) => void;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
operationId: string;
toolName: string;
};
}
export function ChatContainer({
@@ -32,7 +26,6 @@ export function ChatContainer({
className,
onStreamingChange,
onOperationStarted,
activeStream,
}: ChatContainerProps) {
const {
messages,
@@ -48,13 +41,16 @@ export function ChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
});
useEffect(() => {
onStreamingChange?.(isStreaming);
}, [isStreaming, onStreamingChange]);
const breakpoint = useBreakpoint();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
return (
<div
className={cn(
@@ -122,7 +118,11 @@ export function ChatContainer({
disabled={isStreaming || !sessionId}
isStreaming={isStreaming}
onStop={stopStreaming}
placeholder="What else can I help with?"
placeholder={
isMobile
? "You can search or just ask"
: 'You can search or just ask — e.g. "create a blog post outline"'
}
/>
</div>
</div>

View File

@@ -2,7 +2,6 @@ import { toast } from "sonner";
import type { StreamChunk } from "../../chat-types";
import type { HandlerDependencies } from "./handlers";
import {
getErrorDisplayMessage,
handleError,
handleLoginNeeded,
handleStreamEnd,
@@ -25,22 +24,16 @@ export function createStreamEventDispatcher(
chunk.type === "need_login" ||
chunk.type === "error"
) {
if (!deps.hasResponseRef.current) {
console.info("[ChatStream] First response chunk:", {
type: chunk.type,
sessionId: deps.sessionId,
});
}
deps.hasResponseRef.current = true;
}
switch (chunk.type) {
case "stream_start":
// Store task ID for SSE reconnection
if (chunk.taskId && deps.onActiveTaskStarted) {
deps.onActiveTaskStarted({
taskId: chunk.taskId,
operationId: chunk.taskId,
toolName: "chat",
toolCallId: "chat_stream",
});
}
break;
case "text_chunk":
handleTextChunk(chunk, deps);
break;
@@ -63,7 +56,11 @@ export function createStreamEventDispatcher(
break;
case "stream_end":
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
console.info("[ChatStream] Stream ended:", {
sessionId: deps.sessionId,
hasResponse: deps.hasResponseRef.current,
chunkCount: deps.streamingChunksRef.current.length,
});
handleStreamEnd(chunk, deps);
break;
@@ -73,7 +70,7 @@ export function createStreamEventDispatcher(
// Show toast at dispatcher level to avoid circular dependencies
if (!isRegionBlocked) {
toast.error("Chat Error", {
description: getErrorDisplayMessage(chunk),
description: chunk.message || chunk.content || "An error occurred",
});
}
break;

View File

@@ -18,19 +18,11 @@ export interface HandlerDependencies {
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
streamingChunksRef: MutableRefObject<string[]>;
hasResponseRef: MutableRefObject<boolean>;
textFinalizedRef: MutableRefObject<boolean>;
streamEndedRef: MutableRefObject<boolean>;
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
sessionId: string;
onOperationStarted?: () => void;
onActiveTaskStarted?: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -40,25 +32,6 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean {
return message.toLowerCase().includes("not available in your region");
}
export function getUserFriendlyErrorMessage(
code: string | undefined,
): string | undefined {
switch (code) {
case "TASK_EXPIRED":
return "This operation has expired. Please try again.";
case "TASK_NOT_FOUND":
return "Could not find the requested operation.";
case "ACCESS_DENIED":
return "You do not have access to this operation.";
case "QUEUE_OVERFLOW":
return "Connection was interrupted. Please refresh to continue.";
case "MODEL_NOT_AVAILABLE_REGION":
return "This model is not available in your region.";
default:
return undefined;
}
}
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
if (!chunk.content) return;
deps.setHasTextChunks(true);
@@ -73,15 +46,10 @@ export function handleTextEnded(
_chunk: StreamChunk,
deps: HandlerDependencies,
) {
if (deps.textFinalizedRef.current) {
return;
}
const completedText = deps.streamingChunksRef.current.join("");
if (completedText.trim()) {
deps.textFinalizedRef.current = true;
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
@@ -108,14 +76,9 @@ export function handleToolCallStart(
chunk: StreamChunk,
deps: HandlerDependencies,
) {
// Use deterministic fallback instead of Date.now() to ensure same ID on replay
const toolId =
chunk.tool_id ||
`tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
type: "tool_call",
toolId,
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
toolName: chunk.tool_name || "Executing",
arguments: chunk.arguments || {},
timestamp: new Date(),
@@ -148,29 +111,6 @@ export function handleToolCallStart(
deps.setMessages(updateToolCallMessages);
}
const TOOL_RESPONSE_TYPES = new Set([
"tool_response",
"operation_started",
"operation_pending",
"operation_in_progress",
"execution_started",
"agent_carousel",
"clarification_needed",
]);
function hasResponseForTool(
messages: ChatMessageData[],
toolId: string,
): boolean {
return messages.some((msg) => {
if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
const msgToolId =
(msg as { toolId?: string }).toolId ||
(msg as { toolCallId?: string }).toolCallId;
return msgToolId === toolId;
});
}
export function handleToolResponse(
chunk: StreamChunk,
deps: HandlerDependencies,
@@ -212,49 +152,31 @@ export function handleToolResponse(
) {
const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
if (inputsMessage) {
deps.setMessages((prev) => {
// Check for duplicate inputs_needed message
const exists = prev.some((msg) => msg.type === "inputs_needed");
if (exists) return prev;
return [...prev, inputsMessage];
});
deps.setMessages((prev) => [...prev, inputsMessage]);
}
const credentialsMessage = extractCredentialsNeeded(
parsedResult,
chunk.tool_name,
);
if (credentialsMessage) {
deps.setMessages((prev) => {
// Check for duplicate credentials_needed message
const exists = prev.some((msg) => msg.type === "credentials_needed");
if (exists) return prev;
return [...prev, credentialsMessage];
});
deps.setMessages((prev) => [...prev, credentialsMessage]);
}
}
return;
}
// Trigger polling when operation_started is received
if (responseMessage.type === "operation_started") {
deps.onOperationStarted?.();
const taskId = (responseMessage as { taskId?: string }).taskId;
if (taskId && deps.onActiveTaskStarted) {
deps.onActiveTaskStarted({
taskId,
operationId:
(responseMessage as { operationId?: string }).operationId || "",
toolName: (responseMessage as { toolName?: string }).toolName || "",
toolCallId: (responseMessage as { toolId?: string }).toolId || "",
});
}
}
deps.setMessages((prev) => {
const toolCallIndex = prev.findIndex(
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
);
if (hasResponseForTool(prev, chunk.tool_id!)) {
return prev;
}
const hasResponse = prev.some(
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
);
if (hasResponse) return prev;
if (toolCallIndex !== -1) {
const newMessages = [...prev];
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
@@ -276,48 +198,28 @@ export function handleLoginNeeded(
agentInfo: chunk.agent_info,
timestamp: new Date(),
};
deps.setMessages((prev) => {
// Check for duplicate login_needed message
const exists = prev.some((msg) => msg.type === "login_needed");
if (exists) return prev;
return [...prev, loginNeededMessage];
});
deps.setMessages((prev) => [...prev, loginNeededMessage]);
}
export function handleStreamEnd(
_chunk: StreamChunk,
deps: HandlerDependencies,
) {
if (deps.streamEndedRef.current) {
return;
}
deps.streamEndedRef.current = true;
const completedContent = deps.streamingChunksRef.current.join("");
if (!completedContent.trim() && !deps.hasResponseRef.current) {
deps.setMessages((prev) => {
const exists = prev.some(
(msg) =>
msg.type === "message" &&
msg.role === "assistant" &&
msg.content === "No response received. Please try again.",
);
if (exists) return prev;
return [
...prev,
{
type: "message",
role: "assistant",
content: "No response received. Please try again.",
timestamp: new Date(),
},
];
});
deps.setMessages((prev) => [
...prev,
{
type: "message",
role: "assistant",
content: "No response received. Please try again.",
timestamp: new Date(),
},
]);
}
if (completedContent.trim() && !deps.textFinalizedRef.current) {
deps.textFinalizedRef.current = true;
if (completedContent.trim()) {
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
@@ -342,6 +244,8 @@ export function handleStreamEnd(
}
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
const errorMessage = chunk.message || chunk.content || "An error occurred";
console.error("Stream error:", errorMessage);
if (isRegionBlockedError(chunk)) {
deps.setIsRegionBlockedModalOpen(true);
}
@@ -349,14 +253,4 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
deps.setHasTextChunks(false);
deps.setStreamingChunks([]);
deps.streamingChunksRef.current = [];
deps.textFinalizedRef.current = false;
deps.streamEndedRef.current = true;
}
export function getErrorDisplayMessage(chunk: StreamChunk): string {
const friendlyMessage = getUserFriendlyErrorMessage(chunk.code);
if (friendlyMessage) {
return friendlyMessage;
}
return chunk.message || chunk.content || "An error occurred";
}

View File

@@ -349,7 +349,6 @@ export function parseToolResponse(
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
message:
(parsedResult.message as string) ||
"Operation started. You can close this tab.",

View File

@@ -1,17 +1,10 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { useEffect, useMemo, useRef, useState } from "react";
import { INITIAL_STREAM_ID } from "../../chat-constants";
import { useChatStore } from "../../chat-store";
import { toast } from "sonner";
import { useChatStream } from "../../useChatStream";
import { usePageContext } from "../../usePageContext";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import {
getToolIdFromMessage,
hasToolId,
isOperationMessage,
type StreamChunk,
} from "../../chat-types";
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
import {
createUserMessage,
@@ -21,13 +14,6 @@ import {
processInitialMessages,
} from "./helpers";
const TOOL_RESULT_TYPES = new Set([
"tool_response",
"agent_carousel",
"execution_started",
"clarification_needed",
]);
// Helper to generate deduplication key for a message
function getMessageKey(msg: ChatMessageData): string {
if (msg.type === "message") {
@@ -37,18 +23,14 @@ function getMessageKey(msg: ChatMessageData): string {
return `msg:${msg.role}:${msg.content}`;
} else if (msg.type === "tool_call") {
return `toolcall:${msg.toolId}`;
} else if (TOOL_RESULT_TYPES.has(msg.type)) {
// Unified key for all tool result types - same toolId with different types
// (tool_response vs agent_carousel) should deduplicate to the same key
const toolId = getToolIdFromMessage(msg);
// If no toolId, fall back to content-based key to avoid empty key collisions
if (!toolId) {
return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
}
return `toolresult:${toolId}`;
} else if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg) || "";
return `op:${toolId}:${msg.toolName}`;
} else if (msg.type === "tool_response") {
return `toolresponse:${(msg as any).toolId}`;
} else if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
} else {
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
}
@@ -59,13 +41,6 @@ interface Args {
initialMessages: SessionDetailResponse["messages"];
initialPrompt?: string;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
operationId: string;
toolName: string;
};
}
export function useChatContainer({
@@ -73,7 +48,6 @@ export function useChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
}: Args) {
const [messages, setMessages] = useState<ChatMessageData[]>([]);
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -83,8 +57,6 @@ export function useChatContainer({
useState(false);
const hasResponseRef = useRef(false);
const streamingChunksRef = useRef<string[]>([]);
const textFinalizedRef = useRef(false);
const streamEndedRef = useRef(false);
const previousSessionIdRef = useRef<string | null>(null);
const {
error,
@@ -93,182 +65,44 @@ export function useChatContainer({
} = useChatStream();
const activeStreams = useChatStore((s) => s.activeStreams);
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
const setActiveTask = useChatStore((s) => s.setActiveTask);
const getActiveTask = useChatStore((s) => s.getActiveTask);
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
const isStreaming = isStreamingInitiated || hasTextChunks;
// Track whether we've already connected to this activeStream to avoid duplicate connections
const connectedActiveStreamRef = useRef<string | null>(null);
// Track if component is mounted to prevent state updates after unmount
const isMountedRef = useRef(true);
// Track current dispatcher to prevent multiple dispatchers from adding messages
const currentDispatcherIdRef = useRef(0);
// Set mounted flag - reset on every mount, cleanup on unmount
useEffect(function trackMountedState() {
isMountedRef.current = true;
return function cleanup() {
isMountedRef.current = false;
};
}, []);
// Callback to store active task info for SSE reconnection
function handleActiveTaskStarted(taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) {
if (!sessionId) return;
setActiveTask(sessionId, {
taskId: taskInfo.taskId,
operationId: taskInfo.operationId,
toolName: taskInfo.toolName,
lastMessageId: INITIAL_STREAM_ID,
});
}
// Create dispatcher for stream events - stable reference for current sessionId
// Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
function createDispatcher() {
if (!sessionId) return () => {};
// Increment dispatcher ID - only the most recent dispatcher should update state
const dispatcherId = ++currentDispatcherIdRef.current;
const baseDispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
textFinalizedRef,
streamEndedRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
// Wrap dispatcher to check if it's still the current one
return function guardedDispatcher(chunk: StreamChunk) {
// Skip if component unmounted or this is a stale dispatcher
if (!isMountedRef.current) {
return;
}
if (dispatcherId !== currentDispatcherIdRef.current) {
return;
}
baseDispatcher(chunk);
};
}
useEffect(
function handleSessionChange() {
const isSessionChange = sessionId !== previousSessionIdRef.current;
if (sessionId === previousSessionIdRef.current) return;
// Handle session change - reset state
if (isSessionChange) {
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
}
previousSessionIdRef.current = sessionId;
connectedActiveStreamRef.current = null;
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
textFinalizedRef.current = false;
streamEndedRef.current = false;
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
}
previousSessionIdRef.current = sessionId;
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
if (!sessionId) return;
// Priority 1: Check if server told us there's an active stream (most authoritative)
if (activeStream) {
const streamKey = `${sessionId}:${activeStream.taskId}`;
const activeStream = activeStreams.get(sessionId);
if (!activeStream || activeStream.status !== "streaming") return;
if (connectedActiveStreamRef.current === streamKey) {
return;
}
// Skip if there's already an active stream for this session in the store
const existingStream = activeStreams.get(sessionId);
if (existingStream && existingStream.status === "streaming") {
connectedActiveStreamRef.current = streamKey;
return;
}
connectedActiveStreamRef.current = streamKey;
// Clear all state before reconnection to prevent duplicates
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
textFinalizedRef.current = false;
streamEndedRef.current = false;
hasResponseRef.current = false;
setIsStreamingInitiated(true);
setActiveTask(sessionId, {
taskId: activeStream.taskId,
operationId: activeStream.operationId,
toolName: activeStream.toolName,
lastMessageId: activeStream.lastMessageId,
});
reconnectToTask(
sessionId,
activeStream.taskId,
activeStream.lastMessageId,
createDispatcher(),
);
// Don't return cleanup here - the guarded dispatcher handles stale events
// and the stream will complete naturally. Cleanup would prematurely stop
// the stream when effect re-runs due to activeStreams changing.
return;
}
// Only check localStorage/in-memory on session change
if (!isSessionChange) return;
// Priority 2: Check localStorage for active task
const activeTask = getActiveTask(sessionId);
if (activeTask) {
// Clear all state before reconnection to prevent duplicates
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
textFinalizedRef.current = false;
streamEndedRef.current = false;
hasResponseRef.current = false;
setIsStreamingInitiated(true);
reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId,
createDispatcher(),
);
// Don't return cleanup here - the guarded dispatcher handles stale events
return;
}
// Priority 3: Check for an in-memory active stream (same-tab scenario)
const inMemoryStream = activeStreams.get(sessionId);
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
return;
}
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
});
setIsStreamingInitiated(true);
const skipReplay = initialMessages.length > 0;
return subscribeToStream(sessionId, createDispatcher(), skipReplay);
return subscribeToStream(sessionId, dispatcher, skipReplay);
},
[
sessionId,
@@ -276,10 +110,6 @@ export function useChatContainer({
activeStreams,
subscribeToStream,
onOperationStarted,
getActiveTask,
reconnectToTask,
activeStream,
setActiveTask,
],
);
@@ -294,7 +124,7 @@ export function useChatContainer({
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
const toolId = hasToolId(msg) ? msg.toolId : undefined;
const toolId = (msg as any).toolId;
if (toolId) {
ids.add(toolId);
}
@@ -311,8 +141,12 @@ export function useChatContainer({
setMessages((prev) => {
const filtered = prev.filter((msg) => {
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}
@@ -340,8 +174,12 @@ export function useChatContainer({
// Filter local messages: remove duplicates and completed operation messages
const newLocalMessages = messages.filter((msg) => {
// Remove operation messages for completed tools
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false;
}
@@ -352,70 +190,7 @@ export function useChatContainer({
});
// Server messages first (correct order), then new local messages
const combined = [...processedInitial, ...newLocalMessages];
// Post-processing: Remove duplicate assistant messages that can occur during
// race conditions (e.g., rapid screen switching during SSE reconnection).
// Two assistant messages are considered duplicates if:
// - They are both text messages with role "assistant"
// - One message's content starts with the other's content (partial vs complete)
// - Or they have very similar content (>80% overlap at the start)
const deduplicated: ChatMessageData[] = [];
for (let i = 0; i < combined.length; i++) {
const current = combined[i];
// Check if this is an assistant text message
if (current.type !== "message" || current.role !== "assistant") {
deduplicated.push(current);
continue;
}
// Look for duplicate assistant messages in the rest of the array
let dominated = false;
for (let j = 0; j < combined.length; j++) {
if (i === j) continue;
const other = combined[j];
if (other.type !== "message" || other.role !== "assistant") continue;
const currentContent = current.content || "";
const otherContent = other.content || "";
// Skip empty messages
if (!currentContent.trim() || !otherContent.trim()) continue;
// Check if current is a prefix of other (current is incomplete version)
if (
otherContent.length > currentContent.length &&
otherContent.startsWith(currentContent.slice(0, 100))
) {
// Current is a shorter/incomplete version of other - skip it
dominated = true;
break;
}
// Check if messages are nearly identical (within a small difference)
// This catches cases where content differs only slightly
const minLen = Math.min(currentContent.length, otherContent.length);
const compareLen = Math.min(minLen, 200); // Compare first 200 chars
if (
compareLen > 50 &&
currentContent.slice(0, compareLen) ===
otherContent.slice(0, compareLen)
) {
// Same prefix - keep the longer one
if (otherContent.length > currentContent.length) {
dominated = true;
break;
}
}
}
if (!dominated) {
deduplicated.push(current);
}
}
return deduplicated;
return [...processedInitial, ...newLocalMessages];
}, [initialMessages, messages, completedToolIds]);
async function sendMessage(
@@ -423,8 +198,10 @@ export function useChatContainer({
isUserMessage: boolean = true,
context?: { url: string; content: string },
) {
if (!sessionId) return;
if (!sessionId) {
console.error("[useChatContainer] Cannot send message: no session ID");
return;
}
setIsRegionBlockedModalOpen(false);
if (isUserMessage) {
const userMessage = createUserMessage(content);
@@ -437,19 +214,31 @@ export function useChatContainer({
setHasTextChunks(false);
setIsStreamingInitiated(true);
hasResponseRef.current = false;
textFinalizedRef.current = false;
streamEndedRef.current = false;
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
});
try {
await sendStreamMessage(
sessionId,
content,
createDispatcher(),
dispatcher,
isUserMessage,
context,
);
} catch (err) {
console.error("[useChatContainer] Failed to send message:", err);
setIsStreamingInitiated(false);
if (err instanceof Error && err.name === "AbortError") return;
const errorMessage =

View File

@@ -57,7 +57,6 @@ export function ChatInput({
isStreaming,
value,
baseHandleKeyDown,
inputId,
});
return (
@@ -74,20 +73,19 @@ export function ChatInput({
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)}
>
{!value && !isRecording && (
<div
className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
aria-hidden="true"
>
{isTranscribing ? "Transcribing..." : placeholder}
</div>
)}
<textarea
id={inputId}
aria-label="Chat message input"
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={
isTranscribing
? "Transcribing..."
: isRecording
? ""
: placeholder
}
disabled={isInputDisabled}
rows={1}
className={cn(
@@ -123,14 +121,13 @@ export function ChatInput({
size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording}
disabled={disabled || isTranscribing || isStreaming}
disabled={disabled || isTranscribing}
className={cn(
isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
isStreaming && "opacity-40",
)}
>
{isTranscribing ? (

View File

@@ -38,8 +38,8 @@ export function AudioWaveform({
// Create audio context and analyser
const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser();
analyser.fftSize = 256;
analyser.smoothingTimeConstant = 0.3;
analyser.fftSize = 512;
analyser.smoothingTimeConstant = 0.8;
// Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream);
@@ -73,11 +73,10 @@ export function AudioWaveform({
maxAmplitude = Math.max(maxAmplitude, amplitude);
}
// Normalize amplitude (0-128 range) to 0-1
const normalized = maxAmplitude / 128;
// Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
const boosted = Math.min(1, Math.sqrt(normalized) * 4);
const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
// Map amplitude (0-128) to bar height
const normalized = (maxAmplitude / 128) * 255;
const height =
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
newBars.push(height);
}

View File

@@ -15,7 +15,6 @@ interface Args {
isStreaming?: boolean;
value: string;
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
inputId?: string;
}
export function useVoiceRecording({
@@ -24,7 +23,6 @@ export function useVoiceRecording({
isStreaming = false,
value,
baseHandleKeyDown,
inputId,
}: Args) {
const [isRecording, setIsRecording] = useState(false);
const [isTranscribing, setIsTranscribing] = useState(false);
@@ -105,7 +103,7 @@ export function useVoiceRecording({
setIsTranscribing(false);
}
},
[handleTranscription, inputId],
[handleTranscription],
);
const stopRecording = useCallback(() => {
@@ -203,15 +201,6 @@ export function useVoiceRecording({
}
}, [error, toast]);
useEffect(() => {
if (!isTranscribing && inputId) {
const inputElement = document.getElementById(inputId);
if (inputElement) {
inputElement.focus();
}
}
}, [isTranscribing, inputId]);
const handleKeyDown = useCallback(
(event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === " " && !value.trim() && !isTranscribing) {
@@ -224,7 +213,7 @@ export function useVoiceRecording({
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
);
const showMicButton = isSupported;
const showMicButton = isSupported && !isStreaming;
const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount

View File

@@ -156,19 +156,11 @@ export function ChatMessage({
}
if (isClarificationNeeded && message.type === "clarification_needed") {
const hasUserReplyAfter =
index >= 0 &&
messages
.slice(index + 1)
.some((m) => m.type === "message" && m.role === "user");
return (
<ClarificationQuestionsWidget
questions={message.questions}
message={message.message}
sessionId={message.sessionId}
onSubmitAnswers={handleClarificationAnswers}
isAnswered={hasUserReplyAfter}
className={className}
/>
);

View File

@@ -111,7 +111,6 @@ export type ChatMessageData =
toolName: string;
toolId: string;
operationId: string;
taskId?: string; // For SSE reconnection
message: string;
timestamp?: string | Date;
}

View File

@@ -6,7 +6,7 @@ import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { CheckCircleIcon, QuestionIcon } from "@phosphor-icons/react";
import { useState, useEffect, useRef } from "react";
import { useState } from "react";
export interface ClarifyingQuestion {
question: string;
@@ -17,96 +17,39 @@ export interface ClarifyingQuestion {
interface Props {
questions: ClarifyingQuestion[];
message: string;
sessionId?: string;
onSubmitAnswers: (answers: Record<string, string>) => void;
onCancel?: () => void;
isAnswered?: boolean;
className?: string;
}
function getStorageKey(sessionId?: string): string | null {
if (!sessionId) return null;
return `clarification_answers_${sessionId}`;
}
export function ClarificationQuestionsWidget({
questions,
message,
sessionId,
onSubmitAnswers,
onCancel,
isAnswered = false,
className,
}: Props) {
const [answers, setAnswers] = useState<Record<string, string>>({});
const [isSubmitted, setIsSubmitted] = useState(false);
const lastSessionIdRef = useRef<string | undefined>(undefined);
useEffect(() => {
const storageKey = getStorageKey(sessionId);
if (!storageKey) {
setAnswers({});
setIsSubmitted(false);
lastSessionIdRef.current = sessionId;
return;
}
try {
const saved = localStorage.getItem(storageKey);
if (saved) {
const parsed = JSON.parse(saved) as Record<string, string>;
setAnswers(parsed);
} else {
setAnswers({});
}
setIsSubmitted(false);
} catch {
setAnswers({});
setIsSubmitted(false);
}
lastSessionIdRef.current = sessionId;
}, [sessionId]);
useEffect(() => {
if (lastSessionIdRef.current !== sessionId) {
return;
}
const storageKey = getStorageKey(sessionId);
if (!storageKey) return;
const hasAnswers = Object.values(answers).some((v) => v.trim());
try {
if (hasAnswers) {
localStorage.setItem(storageKey, JSON.stringify(answers));
} else {
localStorage.removeItem(storageKey);
}
} catch {}
}, [answers, sessionId]);
function handleAnswerChange(keyword: string, value: string) {
setAnswers((prev) => ({ ...prev, [keyword]: value }));
}
function handleSubmit() {
// Check if all questions are answered
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
if (!allAnswered) {
return;
}
setIsSubmitted(true);
onSubmitAnswers(answers);
const storageKey = getStorageKey(sessionId);
try {
if (storageKey) {
localStorage.removeItem(storageKey);
}
} catch {}
}
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
if (isAnswered || isSubmitted) {
// Show submitted state after answers are submitted
if (isSubmitted) {
return (
<div
className={cn(

Some files were not shown because too many files have changed in this diff Show More