diff --git a/.dockerignore b/.dockerignore index 9b744e7f9b..427cab29f4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,42 +5,13 @@ !docs/ # Platform - Libs -!autogpt_platform/autogpt_libs/autogpt_libs/ -!autogpt_platform/autogpt_libs/pyproject.toml -!autogpt_platform/autogpt_libs/poetry.lock -!autogpt_platform/autogpt_libs/README.md +!autogpt_platform/autogpt_libs/ # Platform - Backend -!autogpt_platform/backend/backend/ -!autogpt_platform/backend/test/e2e_test_data.py -!autogpt_platform/backend/migrations/ -!autogpt_platform/backend/schema.prisma -!autogpt_platform/backend/pyproject.toml -!autogpt_platform/backend/poetry.lock -!autogpt_platform/backend/README.md -!autogpt_platform/backend/.env -!autogpt_platform/backend/gen_prisma_types_stub.py - -# Platform - Market -!autogpt_platform/market/market/ -!autogpt_platform/market/scripts.py -!autogpt_platform/market/schema.prisma -!autogpt_platform/market/pyproject.toml -!autogpt_platform/market/poetry.lock -!autogpt_platform/market/README.md +!autogpt_platform/backend/ # Platform - Frontend -!autogpt_platform/frontend/src/ -!autogpt_platform/frontend/public/ -!autogpt_platform/frontend/scripts/ -!autogpt_platform/frontend/package.json -!autogpt_platform/frontend/pnpm-lock.yaml -!autogpt_platform/frontend/tsconfig.json -!autogpt_platform/frontend/README.md -## config -!autogpt_platform/frontend/*.config.* -!autogpt_platform/frontend/.env.* -!autogpt_platform/frontend/.env +!autogpt_platform/frontend/ # Classic - AutoGPT !classic/original_autogpt/autogpt/ @@ -64,6 +35,38 @@ # Classic - Frontend !classic/frontend/build/web/ -# Explicitly re-ignore some folders -.* -**/__pycache__ +# Explicitly re-ignore unwanted files from whitelisted directories +# Note: These patterns MUST come after the whitelist rules to take effect + +# Hidden files and directories (but keep frontend .env files needed for build) +**/.* +!autogpt_platform/frontend/.env +!autogpt_platform/frontend/.env.default +!autogpt_platform/frontend/.env.production + +# Python artifacts +**/__pycache__/ +**/*.pyc +**/*.pyo +**/.venv/ +**/.ruff_cache/ +**/.pytest_cache/ +**/.coverage +**/htmlcov/ + +# Node artifacts +**/node_modules/ +**/.next/ +**/storybook-static/ +**/playwright-report/ +**/test-results/ + +# Build artifacts +**/dist/ +**/build/ +!autogpt_platform/frontend/src/**/build/ +**/target/ + +# Logs and temp files +**/*.log +**/*.tmp diff --git a/.github/scripts/detect_overlaps.py b/.github/scripts/detect_overlaps.py new file mode 100644 index 0000000000..1f9f4be7cf --- /dev/null +++ b/.github/scripts/detect_overlaps.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python3 +""" +PR Overlap Detection Tool + +Detects potential merge conflicts between a given PR and other open PRs +by checking for file overlap, line overlap, and actual merge conflicts. 
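+
+Usage (the PR number here is illustrative):
+
+    python .github/scripts/detect_overlaps.py 1234 --dry-run
+
+Requires an authenticated GitHub CLI (`gh`), e.g. via GH_TOKEN in CI.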
+""" + +import json +import os +import re +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from typing import Optional + + +# ============================================================================= +# MAIN ENTRY POINT +# ============================================================================= + +def main(): + """Main entry point for PR overlap detection.""" + import argparse + + parser = argparse.ArgumentParser(description="Detect PR overlaps and potential merge conflicts") + parser.add_argument("pr_number", type=int, help="PR number to check") + parser.add_argument("--base", default=None, help="Base branch (default: auto-detect from PR)") + parser.add_argument("--skip-merge-test", action="store_true", help="Skip actual merge conflict testing") + parser.add_argument("--discord-webhook", default=os.environ.get("DISCORD_WEBHOOK_URL"), help="Discord webhook URL for notifications") + parser.add_argument("--dry-run", action="store_true", help="Don't post comments, just print") + + args = parser.parse_args() + + owner, repo = get_repo_info() + print(f"Checking PR #{args.pr_number} in {owner}/{repo}") + + # Get current PR info + current_pr = fetch_pr_details(args.pr_number) + base_branch = args.base or current_pr.base_ref + + print(f"PR #{current_pr.number}: {current_pr.title}") + print(f"Base branch: {base_branch}") + print(f"Files changed: {len(current_pr.files)}") + + # Find overlapping PRs + overlaps, all_changes = find_overlapping_prs( + owner, repo, base_branch, current_pr, args.pr_number, args.skip_merge_test + ) + + if not overlaps: + print("No overlaps detected!") + return + + # Generate and post report + comment = format_comment(overlaps, args.pr_number, current_pr.changed_ranges, all_changes) + + if args.dry_run: + print("\n" + "="*60) + print("COMMENT PREVIEW:") + print("="*60) + print(comment) + else: + if comment: + post_or_update_comment(args.pr_number, comment) + print("Posted comment to PR") + + if args.discord_webhook: + send_discord_notification(args.discord_webhook, current_pr, overlaps) + + # Report results and exit + report_results(overlaps) + + +# ============================================================================= +# HIGH-LEVEL WORKFLOW FUNCTIONS +# ============================================================================= + +def fetch_pr_details(pr_number: int) -> "PullRequest": + """Fetch details for a specific PR including its diff.""" + result = run_gh(["pr", "view", str(pr_number), "--json", "number,title,url,author,headRefName,baseRefName,files"]) + data = json.loads(result.stdout) + + pr = PullRequest( + number=data["number"], + title=data["title"], + author=data["author"]["login"] if data.get("author") else "unknown", + url=data["url"], + head_ref=data["headRefName"], + base_ref=data["baseRefName"], + files=[f["path"] for f in data["files"]], + changed_ranges={} + ) + + # Get detailed diff + diff = get_pr_diff(pr_number) + pr.changed_ranges = parse_diff_ranges(diff) + + return pr + + +def find_overlapping_prs( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + current_pr_number: int, + skip_merge_test: bool +) -> tuple[list["Overlap"], dict[int, dict[str, "ChangedFile"]]]: + """Find all PRs that overlap with the current PR.""" + # Query other open PRs + all_prs = query_open_prs(owner, repo, base_branch) + other_prs = [p for p in all_prs if p["number"] != current_pr_number] + + print(f"Found {len(other_prs)} other open PRs targeting {base_branch}") + + # Find file overlaps 
(excluding ignored files, filtering by age) + candidates = find_file_overlap_candidates(current_pr.files, other_prs) + + print(f"Found {len(candidates)} PRs with file overlap (excluding ignored files)") + + if not candidates: + return [], {} + + # First pass: analyze line overlaps (no merge testing yet) + overlaps = [] + all_changes = {} + prs_needing_merge_test = [] + + for pr_data, shared_files in candidates: + overlap, pr_changes = analyze_pr_overlap( + owner, repo, base_branch, current_pr, pr_data, shared_files, + skip_merge_test=True # Always skip in first pass + ) + if overlap: + overlaps.append(overlap) + all_changes[pr_data["number"]] = pr_changes + # Track PRs that need merge testing + if overlap.line_overlaps and not skip_merge_test: + prs_needing_merge_test.append(overlap) + + # Second pass: batch merge testing with shared clone + if prs_needing_merge_test: + run_batch_merge_tests(owner, repo, base_branch, current_pr, prs_needing_merge_test) + + return overlaps, all_changes + + +def run_batch_merge_tests( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + overlaps: list["Overlap"] +): + """Run merge tests for multiple PRs using a shared clone.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone once + if not clone_repo(owner, repo, base_branch, tmpdir): + return + + configure_git(tmpdir) + + # Fetch current PR branch once + result = run_git(["fetch", "origin", f"pull/{current_pr.number}/head:pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch current PR #{current_pr.number}", file=sys.stderr) + return + + for overlap in overlaps: + other_pr = overlap.pr_b if overlap.pr_a.number == current_pr.number else overlap.pr_a + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + + # Clean up any in-progress merge from previous iteration + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + # Reset to base branch + run_git(["checkout", base_branch], cwd=tmpdir, check=False) + run_git(["reset", "--hard", f"origin/{base_branch}"], cwd=tmpdir, check=False) + run_git(["clean", "-fdx"], cwd=tmpdir, check=False) + + # Fetch the other PR branch + result = run_git(["fetch", "origin", f"pull/{other_pr.number}/head:pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{other_pr.number}: {result.stderr.strip()}", file=sys.stderr) + continue + + # Try merging current PR first + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Current PR conflicts with base + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'pr_a_conflicts_base' + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + continue + + # Commit and try merging other PR + run_git(["commit", "-m", f"Merge PR #{current_pr.number}"], cwd=tmpdir, check=False) + + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Conflict between PRs + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'conflict' + run_git(["merge", "--abort"], 
cwd=tmpdir, check=False) + + +def analyze_pr_overlap( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + other_pr_data: dict, + shared_files: list[str], + skip_merge_test: bool +) -> tuple[Optional["Overlap"], dict[str, "ChangedFile"]]: + """Analyze overlap between current PR and another PR.""" + # Filter out ignored files + non_ignored_shared = [f for f in shared_files if not should_ignore_file(f)] + if not non_ignored_shared: + return None, {} + + other_pr = PullRequest( + number=other_pr_data["number"], + title=other_pr_data["title"], + author=other_pr_data["author"], + url=other_pr_data["url"], + head_ref=other_pr_data["head_ref"], + base_ref=other_pr_data["base_ref"], + files=other_pr_data["files"], + changed_ranges={}, + updated_at=other_pr_data.get("updated_at") + ) + + # Get diff for other PR + other_diff = get_pr_diff(other_pr.number) + other_pr.changed_ranges = parse_diff_ranges(other_diff) + + # Check line overlaps + line_overlaps = find_line_overlaps( + current_pr.changed_ranges, + other_pr.changed_ranges, + shared_files + ) + + overlap = Overlap( + pr_a=current_pr, + pr_b=other_pr, + overlapping_files=non_ignored_shared, + line_overlaps=line_overlaps + ) + + # Test for actual merge conflicts if we have line overlaps + if line_overlaps and not skip_merge_test: + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + has_conflict, conflict_files, conflict_details, error_type = test_merge_conflict( + owner, repo, base_branch, current_pr, other_pr + ) + overlap.has_merge_conflict = has_conflict + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = error_type + + return overlap, other_pr.changed_ranges + + +def find_file_overlap_candidates( + current_files: list[str], + other_prs: list[dict], + max_age_days: int = 14 +) -> list[tuple[dict, list[str]]]: + """Find PRs that share files with the current PR.""" + from datetime import datetime, timezone, timedelta + + current_files_set = set(f for f in current_files if not should_ignore_file(f)) + candidates = [] + cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days) + + for pr_data in other_prs: + # Filter out PRs older than max_age_days + updated_at = pr_data.get("updated_at") + if updated_at: + try: + pr_date = datetime.fromisoformat(updated_at.replace('Z', '+00:00')) + if pr_date < cutoff_date: + continue # Skip old PRs + except Exception as e: + # If we can't parse date, include the PR (safe fallback) + print(f"Warning: Could not parse date for PR: {e}", file=sys.stderr) + + other_files = set(f for f in pr_data["files"] if not should_ignore_file(f)) + shared = current_files_set & other_files + + if shared: + candidates.append((pr_data, list(shared))) + + return candidates + + +def report_results(overlaps: list["Overlap"]): + """Report results (informational only, always exits 0).""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if conflicts: + print(f"\n⚠️ Found {len(conflicts)} merge conflict(s)") + + line_overlap_count = len([o for o in overlaps if o.line_overlaps]) + if line_overlap_count: + print(f"\n⚠️ Found {line_overlap_count} PR(s) with line overlap") + + print("\n✅ Done") + # Always exit 0 - this check is informational, not a merge blocker + + +# ============================================================================= +# COMMENT FORMATTING +# ============================================================================= + +def format_comment( + overlaps: list["Overlap"], + 
current_pr: int, + changes_current: dict[str, "ChangedFile"], + all_changes: dict[int, dict[str, "ChangedFile"]] +) -> str: + """Format the overlap report as a PR comment.""" + if not overlaps: + return "" + + lines = ["## 🔍 PR Overlap Detection"] + lines.append("") + lines.append("This check compares your PR against all other open PRs targeting the same branch to detect potential merge conflicts early.") + lines.append("") + + # Check if current PR conflicts with base branch + format_base_conflicts(overlaps, lines) + + # Classify and sort overlaps + classified = classify_all_overlaps(overlaps, current_pr, changes_current, all_changes) + + # Group by risk + conflicts = [(o, r) for o, r in classified if r == 'conflict'] + medium_risk = [(o, r) for o, r in classified if r == 'medium'] + low_risk = [(o, r) for o, r in classified if r == 'low'] + + # Format each section + format_conflicts_section(conflicts, current_pr, lines) + format_medium_risk_section(medium_risk, current_pr, changes_current, all_changes, lines) + format_low_risk_section(low_risk, current_pr, lines) + + # Summary + total = len(overlaps) + lines.append(f"\n**Summary:** {len(conflicts)} conflict(s), {len(medium_risk)} medium risk, {len(low_risk)} low risk (out of {total} PRs with file overlap)") + lines.append("\n---\n*Auto-generated on push. Ignores: `openapi.json`, lock files.*") + + return "\n".join(lines) + + +def format_base_conflicts(overlaps: list["Overlap"], lines: list[str]): + """Format base branch conflicts section.""" + base_conflicts = [o for o in overlaps if o.conflict_type == 'pr_a_conflicts_base'] + if base_conflicts: + lines.append("### ⚠️ This PR has conflicts with the base branch\n") + lines.append("Conflicts will need to be resolved before merging:\n") + first = base_conflicts[0] + for f in first.conflict_files[:10]: + lines.append(f"- `{f}`") + if len(first.conflict_files) > 10: + lines.append(f"- ... and {len(first.conflict_files) - 10} more files") + lines.append("\n") + + +def format_conflicts_section(conflicts: list[tuple], current_pr: int, lines: list[str]): + """Format the merge conflicts section.""" + pr_conflicts = [(o, r) for o, r in conflicts if o.conflict_type != 'pr_a_conflicts_base'] + + if not pr_conflicts: + return + + lines.append("### 🔴 Merge Conflicts Detected") + lines.append("") + lines.append("The following PRs have been tested and **will have merge conflicts** if merged after this PR. 
Consider coordinating with the authors.") + lines.append("") + + for o, _ in pr_conflicts: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + format_pr_entry(other, lines) + format_conflict_details(o, lines) + lines.append("") + + +def format_medium_risk_section( + medium_risk: list[tuple], + current_pr: int, + changes_current: dict, + all_changes: dict, + lines: list[str] +): + """Format the medium risk section.""" + if not medium_risk: + return + + lines.append("### 🟡 Medium Risk — Some Line Overlap\n") + lines.append("These PRs have some overlapping changes:\n") + + for o, _ in medium_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other.number, {}) + format_pr_entry(other, lines) + + # Note if rename is involved + for file_path in o.overlapping_files: + file_a = changes_current.get(file_path) + file_b = other_changes.get(file_path) + if (file_a and file_a.is_rename) or (file_b and file_b.is_rename): + lines.append(f" - ⚠️ `{file_path}` is being renamed/moved") + break + + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") + + +def format_low_risk_section(low_risk: list[tuple], current_pr: int, lines: list[str]): + """Format the low risk section.""" + if not low_risk: + return + + lines.append("### 🟢 Low Risk — File Overlap Only\n") + lines.append("
<details><summary>These PRs touch the same files but different sections (click to expand)</summary>\n")
+
+    for o, _ in low_risk:
+        other = o.pr_b if o.pr_a.number == current_pr else o.pr_a
+        non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)]
+        if non_ignored:
+            format_pr_entry(other, lines)
+            if o.line_overlaps:
+                for file_path, ranges in o.line_overlaps.items():
+                    range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges]
+                    lines.append(f" - `{file_path}`: {', '.join(range_strs)}")
+            else:
+                lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`")
+            lines.append("")  # Add blank line between entries
+
+    lines.append("</details>
\n") + + +def format_pr_entry(pr: "PullRequest", lines: list[str]): + """Format a single PR entry line.""" + updated = format_relative_time(pr.updated_at) + updated_str = f" · updated {updated}" if updated else "" + # Just use #number - GitHub auto-renders it with title + lines.append(f"- #{pr.number} ({pr.author}{updated_str})") + + +def format_conflict_details(overlap: "Overlap", lines: list[str]): + """Format conflict details for a PR.""" + if overlap.conflict_details: + all_paths = [d.path for d in overlap.conflict_details] + common_prefix = find_common_prefix(all_paths) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for detail in overlap.conflict_details: + display_path = detail.path[len(common_prefix):] if common_prefix else detail.path + size_str = format_conflict_size(detail) + lines.append(f" - `{display_path}`{size_str}") + elif overlap.conflict_files: + common_prefix = find_common_prefix(overlap.conflict_files) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for f in overlap.conflict_files: + display_path = f[len(common_prefix):] if common_prefix else f + lines.append(f" - `{display_path}`") + + +def format_conflict_size(detail: "ConflictInfo") -> str: + """Format conflict size string for a file.""" + if detail.conflict_count > 0: + return f" ({detail.conflict_count} conflict{'s' if detail.conflict_count > 1 else ''}, ~{detail.conflict_lines} lines)" + elif detail.conflict_type != 'content': + type_labels = { + 'both_added': 'added in both', + 'both_deleted': 'deleted in both', + 'deleted_by_us': 'deleted here, modified there', + 'deleted_by_them': 'modified here, deleted there', + 'added_by_us': 'added here', + 'added_by_them': 'added there', + } + label = type_labels.get(detail.conflict_type, detail.conflict_type) + return f" ({label})" + return "" + + +def format_line_overlaps(line_overlaps: dict[str, list[tuple]], lines: list[str]): + """Format line overlap details.""" + all_paths = list(line_overlaps.keys()) + common_prefix = find_common_prefix(all_paths) if len(all_paths) > 1 else "" + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for file_path, ranges in line_overlaps.items(): + display_path = file_path[len(common_prefix):] if common_prefix else file_path + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + indent = " " if common_prefix else " " + lines.append(f"{indent}- `{display_path}`: {', '.join(range_strs)}") + + +# ============================================================================= +# OVERLAP ANALYSIS +# ============================================================================= + +def classify_all_overlaps( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict, + all_changes: dict +) -> list[tuple["Overlap", str]]: + """Classify all overlaps by risk level and sort them.""" + classified = [] + for o in overlaps: + other_pr = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other_pr.number, {}) + risk = classify_overlap_risk(o, changes_current, other_changes) + classified.append((o, risk)) + + def sort_key(item): + o, risk = item + risk_order = {'conflict': 0, 'medium': 1, 'low': 2} + # For conflicts, also sort by total conflict lines (descending) + conflict_lines = sum(d.conflict_lines for d in o.conflict_details) if o.conflict_details else 0 + return (risk_order.get(risk, 99), -conflict_lines) + + classified.sort(key=sort_key) + + return classified + + +def classify_overlap_risk( + overlap: "Overlap", + changes_a: 
dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"] +) -> str: + """Classify the risk level of an overlap.""" + if overlap.has_merge_conflict: + return 'conflict' + + has_rename = any( + (changes_a.get(f) and changes_a[f].is_rename) or + (changes_b.get(f) and changes_b[f].is_rename) + for f in overlap.overlapping_files + ) + + if overlap.line_overlaps: + total_overlap_lines = sum( + end - start + 1 + for ranges in overlap.line_overlaps.values() + for start, end in ranges + ) + + # Medium risk: >20 lines overlap or file rename + if total_overlap_lines > 20 or has_rename: + return 'medium' + else: + return 'low' + + if has_rename: + return 'medium' + + return 'low' + + +def find_line_overlaps( + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"], + shared_files: list[str] +) -> dict[str, list[tuple[int, int]]]: + """Find overlapping line ranges in shared files.""" + overlaps = {} + + for file_path in shared_files: + if should_ignore_file(file_path): + continue + + file_a = changes_a.get(file_path) + file_b = changes_b.get(file_path) + + if not file_a or not file_b: + continue + + # Skip pure renames + if file_a.is_rename and not file_a.additions and not file_a.deletions: + continue + if file_b.is_rename and not file_b.additions and not file_b.deletions: + continue + + # Note: This mixes old-file (deletions) and new-file (additions) line numbers, + # which can cause false positives when PRs insert/remove many lines. + # Acceptable for v1 since the real merge test is the authoritative check. + file_overlaps = find_range_overlaps( + file_a.additions + file_a.deletions, + file_b.additions + file_b.deletions + ) + + if file_overlaps: + overlaps[file_path] = merge_ranges(file_overlaps) + + return overlaps + + +def find_range_overlaps( + ranges_a: list[tuple[int, int]], + ranges_b: list[tuple[int, int]] +) -> list[tuple[int, int]]: + """Find overlapping regions between two sets of ranges.""" + overlaps = [] + for range_a in ranges_a: + for range_b in ranges_b: + if ranges_overlap(range_a, range_b): + overlap_start = max(range_a[0], range_b[0]) + overlap_end = min(range_a[1], range_b[1]) + overlaps.append((overlap_start, overlap_end)) + return overlaps + + +def ranges_overlap(range_a: tuple[int, int], range_b: tuple[int, int]) -> bool: + """Check if two line ranges overlap.""" + return range_a[0] <= range_b[1] and range_b[0] <= range_a[1] + + +def merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]: + """Merge overlapping line ranges.""" + if not ranges: + return [] + + sorted_ranges = sorted(ranges, key=lambda x: x[0]) + merged = [sorted_ranges[0]] + + for current in sorted_ranges[1:]: + last = merged[-1] + if current[0] <= last[1] + 1: + merged[-1] = (last[0], max(last[1], current[1])) + else: + merged.append(current) + + return merged + + +# ============================================================================= +# MERGE CONFLICT TESTING +# ============================================================================= + +def test_merge_conflict( + owner: str, + repo: str, + base_branch: str, + pr_a: "PullRequest", + pr_b: "PullRequest" +) -> tuple[bool, list[str], list["ConflictInfo"], str]: + """Test if merging both PRs would cause a conflict.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone repo + if not clone_repo(owner, repo, base_branch, tmpdir): + return False, [], [], None + + configure_git(tmpdir) + if not fetch_pr_branches(tmpdir, pr_a.number, pr_b.number): + # Fetch failed for one or both PRs - can't test merge + 
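# Treat an unfetchable branch as "no conflict" rather than failing the whole run
+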
return False, [], [], None + + # Try merging PR A first + conflict_result = try_merge_pr(tmpdir, pr_a.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'pr_a_conflicts_base' + + # Commit and try merging PR B + run_git(["commit", "-m", f"Merge PR #{pr_a.number}"], cwd=tmpdir, check=False) + + conflict_result = try_merge_pr(tmpdir, pr_b.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'conflict' + + return False, [], [], None + + +def clone_repo(owner: str, repo: str, branch: str, tmpdir: str) -> bool: + """Clone the repository.""" + clone_url = f"https://github.com/{owner}/{repo}.git" + result = run_git( + ["clone", "--depth=50", "--branch", branch, clone_url, tmpdir], + check=False + ) + if result.returncode != 0: + print(f"Failed to clone: {result.stderr}", file=sys.stderr) + return False + return True + + +def configure_git(tmpdir: str): + """Configure git for commits.""" + run_git(["config", "user.email", "github-actions[bot]@users.noreply.github.com"], cwd=tmpdir, check=False) + run_git(["config", "user.name", "github-actions[bot]"], cwd=tmpdir, check=False) + + +def fetch_pr_branches(tmpdir: str, pr_a: int, pr_b: int) -> bool: + """Fetch both PR branches. Returns False if any fetch fails.""" + success = True + for pr_num in (pr_a, pr_b): + result = run_git(["fetch", "origin", f"pull/{pr_num}/head:pr-{pr_num}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{pr_num}: {result.stderr.strip()}", file=sys.stderr) + success = False + return success + + +def try_merge_pr(tmpdir: str, pr_number: int) -> Optional[tuple[list[str], list["ConflictInfo"]]]: + """Try to merge a PR. Returns conflict info if conflicts, None if success.""" + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{pr_number}"], cwd=tmpdir, check=False) + + if result.returncode == 0: + return None + + # Conflict detected + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + return conflict_files, conflict_details + + +def extract_conflict_info(tmpdir: str, stderr: str) -> tuple[list[str], list["ConflictInfo"]]: + """Extract conflict information from git status.""" + status_result = run_git(["status", "--porcelain"], cwd=tmpdir, check=False) + + status_types = { + 'UU': 'content', + 'AA': 'both_added', + 'DD': 'both_deleted', + 'DU': 'deleted_by_us', + 'UD': 'deleted_by_them', + 'AU': 'added_by_us', + 'UA': 'added_by_them', + } + + conflict_files = [] + conflict_details = [] + + for line in status_result.stdout.split("\n"): + if len(line) >= 3 and line[0:2] in status_types: + status_code = line[0:2] + file_path = line[3:].strip() + conflict_files.append(file_path) + + info = analyze_conflict_markers(file_path, tmpdir) + info.conflict_type = status_types.get(status_code, 'unknown') + conflict_details.append(info) + + # Fallback to stderr parsing + if not conflict_files and stderr: + for line in stderr.split("\n"): + if "CONFLICT" in line and ":" in line: + parts = line.split(":") + if len(parts) > 1: + file_part = parts[-1].strip() + if file_part and not file_part.startswith("Merge"): + conflict_files.append(file_part) + conflict_details.append(ConflictInfo(path=file_part)) + + return conflict_files, conflict_details + + +def analyze_conflict_markers(file_path: str, cwd: str) -> "ConflictInfo": + """Analyze a conflicted file to count conflict regions and lines.""" + info = ConflictInfo(path=file_path) + + try: + 
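# Read the conflicted file as it sits in the merge working tree (markers persist until the merge is aborted)
+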
full_path = os.path.join(cwd, file_path) + with open(full_path, 'r', errors='ignore') as f: + content = f.read() + + in_conflict = False + current_conflict_lines = 0 + + for line in content.split('\n'): + if line.startswith('<<<<<<<'): + in_conflict = True + info.conflict_count += 1 + current_conflict_lines = 1 + elif line.startswith('>>>>>>>'): + in_conflict = False + current_conflict_lines += 1 + info.conflict_lines += current_conflict_lines + elif in_conflict: + current_conflict_lines += 1 + except Exception as e: + print(f"Warning: Could not analyze conflict markers in {file_path}: {e}", file=sys.stderr) + + return info + + +# ============================================================================= +# DIFF PARSING +# ============================================================================= + +def parse_diff_ranges(diff: str) -> dict[str, "ChangedFile"]: + """Parse a unified diff and extract changed line ranges per file.""" + files = {} + current_file = None + pending_rename_from = None + is_rename = False + + for line in diff.split("\n"): + # Reset rename state on new file diff header + if line.startswith("diff --git "): + is_rename = False + pending_rename_from = None + elif line.startswith("rename from "): + pending_rename_from = line[12:] + is_rename = True + elif line.startswith("rename to "): + pass # rename target is captured via "+++ b/" line + elif line.startswith("similarity index"): + is_rename = True + elif line.startswith("+++ b/"): + path = line[6:] + current_file = ChangedFile( + path=path, + additions=[], + deletions=[], + is_rename=is_rename, + old_path=pending_rename_from + ) + files[path] = current_file + pending_rename_from = None + is_rename = False + elif line.startswith("--- /dev/null"): + is_rename = False + pending_rename_from = None + elif line.startswith("@@") and current_file: + parse_hunk_header(line, current_file) + + return files + + +def parse_hunk_header(line: str, current_file: "ChangedFile"): + """Parse a diff hunk header and add ranges to the file.""" + match = re.match(r"@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@", line) + if match: + old_start = int(match.group(1)) + old_count = int(match.group(2) or 1) + new_start = int(match.group(3)) + new_count = int(match.group(4) or 1) + + if old_count > 0: + current_file.deletions.append((old_start, old_start + old_count - 1)) + if new_count > 0: + current_file.additions.append((new_start, new_start + new_count - 1)) + + +# ============================================================================= +# GITHUB API +# ============================================================================= + +def get_repo_info() -> tuple[str, str]: + """Get owner and repo name from environment or git.""" + if os.environ.get("GITHUB_REPOSITORY"): + owner, repo = os.environ["GITHUB_REPOSITORY"].split("/") + return owner, repo + + result = run_gh(["repo", "view", "--json", "owner,name"]) + data = json.loads(result.stdout) + return data["owner"]["login"], data["name"] + + +def query_open_prs(owner: str, repo: str, base_branch: str) -> list[dict]: + """Query all open PRs targeting the specified base branch.""" + prs = [] + cursor = None + + while True: + after_clause = f', after: "{cursor}"' if cursor else "" + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequests( + first: 100{after_clause}, + states: OPEN, + baseRefName: "{base_branch}", + orderBy: {{field: UPDATED_AT, direction: DESC}} + ) {{ + totalCount + edges {{ + node {{ + number + title + url + updatedAt + author {{ login }} + headRefName + baseRefName + files(first: 100) {{ + nodes {{ path }} + pageInfo {{ hasNextPage }} + }} + }} + }} + pageInfo {{ + endCursor + hasNextPage + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"]) + data = json.loads(result.stdout) + + if "errors" in data: + print(f"GraphQL errors: {data['errors']}", file=sys.stderr) + sys.exit(1) + + pr_data = data["data"]["repository"]["pullRequests"] + for edge in pr_data["edges"]: + node = edge["node"] + files_data = node["files"] + # Warn if PR has more than 100 files (API limit, we only fetch first 100) + if files_data.get("pageInfo", {}).get("hasNextPage"): + print(f"Warning: PR #{node['number']} has >100 files, overlap detection may be incomplete", file=sys.stderr) + prs.append({ + "number": node["number"], + "title": node["title"], + "url": node["url"], + "updated_at": node.get("updatedAt"), + "author": node["author"]["login"] if node["author"] else "unknown", + "head_ref": node["headRefName"], + "base_ref": node["baseRefName"], + "files": [f["path"] for f in files_data["nodes"]] + }) + + if not pr_data["pageInfo"]["hasNextPage"]: + break + cursor = pr_data["pageInfo"]["endCursor"] + + return prs + + +def get_pr_diff(pr_number: int) -> str: + """Get the diff for a PR.""" + result = run_gh(["pr", "diff", str(pr_number)]) + return result.stdout + + +def post_or_update_comment(pr_number: int, body: str): + """Post a new comment or update existing overlap detection comment.""" + if not body: + return + + marker = "## 🔍 PR Overlap Detection" + + # Find existing comment using GraphQL + owner, repo = get_repo_info() + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequest(number: {pr_number}) {{ + comments(first: 100) {{ + nodes {{ + id + body + author {{ login }} + }} + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"], check=False) + + existing_comment_id = None + if result.returncode == 0: + try: + data = json.loads(result.stdout) + comments = data.get("data", {}).get("repository", {}).get("pullRequest", 
{}).get("comments", {}).get("nodes", []) + for comment in comments: + if marker in comment.get("body", ""): + existing_comment_id = comment["id"] + break + except Exception as e: + print(f"Warning: Could not search for existing comment: {e}", file=sys.stderr) + + if existing_comment_id: + # Update existing comment using GraphQL mutation + # Use json.dumps for proper escaping of all special characters + escaped_body = json.dumps(body)[1:-1] # Strip outer quotes added by json.dumps + mutation = f''' + mutation {{ + updateIssueComment(input: {{id: "{existing_comment_id}", body: "{escaped_body}"}}) {{ + issueComment {{ id }} + }} + }} + ''' + result = run_gh(["api", "graphql", "-f", f"query={mutation}"], check=False) + if result.returncode == 0: + print(f"Updated existing overlap comment") + else: + # Fallback to posting new comment + print(f"Failed to update comment, posting new one: {result.stderr}", file=sys.stderr) + run_gh(["pr", "comment", str(pr_number), "--body", body]) + else: + # Post new comment + run_gh(["pr", "comment", str(pr_number), "--body", body]) + + +def send_discord_notification(webhook_url: str, pr: "PullRequest", overlaps: list["Overlap"]): + """Send a Discord notification about significant overlaps.""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if not conflicts: + return + + # Discord limits: max 25 fields, max 1024 chars per field value + fields = [] + for o in conflicts[:25]: + other = o.pr_b if o.pr_a.number == pr.number else o.pr_a + # Build value string with truncation to stay under 1024 chars + file_list = o.conflict_files[:3] + files_str = f"Files: `{'`, `'.join(file_list)}`" + if len(o.conflict_files) > 3: + files_str += f" (+{len(o.conflict_files) - 3} more)" + value = f"[{other.title[:100]}]({other.url})\n{files_str}" + # Truncate if still too long + if len(value) > 1024: + value = value[:1020] + "..." + fields.append({ + "name": f"Conflicts with #{other.number}", + "value": value, + "inline": False + }) + + embed = { + "title": f"⚠️ PR #{pr.number} has merge conflicts", + "description": f"[{pr.title}]({pr.url})", + "color": 0xFF0000, + "fields": fields + } + + if len(conflicts) > 25: + embed["footer"] = {"text": f"... 
and {len(conflicts) - 25} more conflicts"} + + try: + subprocess.run( + ["curl", "-X", "POST", "-H", "Content-Type: application/json", + "--max-time", "10", + "-d", json.dumps({"embeds": [embed]}), webhook_url], + capture_output=True, + timeout=15 + ) + except subprocess.TimeoutExpired: + print("Warning: Discord webhook timed out", file=sys.stderr) + + +# ============================================================================= +# UTILITIES +# ============================================================================= + +def run_gh(args: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Run a gh CLI command.""" + result = subprocess.run( + ["gh"] + args, + capture_output=True, + text=True, + check=False + ) + if check and result.returncode != 0: + print(f"Error running gh {' '.join(args)}: {result.stderr}", file=sys.stderr) + sys.exit(1) + return result + + +def run_git(args: list[str], cwd: str = None, check: bool = True) -> subprocess.CompletedProcess: + """Run a git command.""" + result = subprocess.run( + ["git"] + args, + capture_output=True, + text=True, + cwd=cwd, + check=False + ) + if check and result.returncode != 0: + print(f"Error running git {' '.join(args)}: {result.stderr}", file=sys.stderr) + return result + + +def should_ignore_file(path: str) -> bool: + """Check if a file should be ignored for overlap detection.""" + if path in IGNORE_FILES: + return True + basename = path.split("/")[-1] + return basename in IGNORE_FILES + + +def find_common_prefix(paths: list[str]) -> str: + """Find the common directory prefix of a list of file paths.""" + if not paths: + return "" + if len(paths) == 1: + parts = paths[0].rsplit('/', 1) + return parts[0] + '/' if len(parts) > 1 else "" + + split_paths = [p.split('/') for p in paths] + common = [] + for parts in zip(*split_paths): + if len(set(parts)) == 1: + common.append(parts[0]) + else: + break + + return '/'.join(common) + '/' if common else "" + + +def format_relative_time(iso_timestamp: str) -> str: + """Format an ISO timestamp as relative time.""" + if not iso_timestamp: + return "" + + from datetime import datetime, timezone + try: + dt = datetime.fromisoformat(iso_timestamp.replace('Z', '+00:00')) + now = datetime.now(timezone.utc) + diff = now - dt + + seconds = diff.total_seconds() + if seconds < 60: + return "just now" + elif seconds < 3600: + return f"{int(seconds / 60)}m ago" + elif seconds < 86400: + return f"{int(seconds / 3600)}h ago" + else: + return f"{int(seconds / 86400)}d ago" + except Exception as e: + print(f"Warning: Could not format relative time: {e}", file=sys.stderr) + return "" + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + +@dataclass +class ChangedFile: + """Represents a file changed in a PR.""" + path: str + additions: list[tuple[int, int]] + deletions: list[tuple[int, int]] + is_rename: bool = False + old_path: str = None + + +@dataclass +class PullRequest: + """Represents a pull request.""" + number: int + title: str + author: str + url: str + head_ref: str + base_ref: str + files: list[str] + changed_ranges: dict[str, ChangedFile] + updated_at: str = None + + +@dataclass +class ConflictInfo: + """Info about a single conflicting file.""" + path: str + conflict_count: int = 0 + conflict_lines: int = 0 + conflict_type: str = "content" + + +@dataclass +class Overlap: + """Represents an overlap between two PRs.""" + pr_a: PullRequest + pr_b: 
PullRequest + overlapping_files: list[str] + line_overlaps: dict[str, list[tuple[int, int]]] + has_merge_conflict: bool = False + conflict_files: list[str] = None + conflict_details: list[ConflictInfo] = None + conflict_type: str = None + + def __post_init__(self): + if self.conflict_files is None: + self.conflict_files = [] + if self.conflict_details is None: + self.conflict_details = [] + + +# ============================================================================= +# CONSTANTS +# ============================================================================= + +IGNORE_FILES = { + "autogpt_platform/frontend/src/app/api/openapi.json", + "poetry.lock", + "pnpm-lock.yaml", + "package-lock.json", + "yarn.lock", +} + + +# ============================================================================= +# ENTRY POINT +# ============================================================================= + +if __name__ == "__main__": + main() diff --git a/.github/workflows/claude-ci-failure-auto-fix.yml b/.github/workflows/claude-ci-failure-auto-fix.yml index ab07c8ae10..dbca6dc3f3 100644 --- a/.github/workflows/claude-ci-failure-auto-fix.yml +++ b/.github/workflows/claude-ci-failure-auto-fix.yml @@ -40,6 +40,48 @@ jobs: git checkout -b "$BRANCH_NAME" echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT + # Backend Python/Poetry setup (so Claude can run linting/tests) + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Set up Python dependency cache + uses: actions/cache@v5 + with: + path: ~/.cache/pypoetry + key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} + + - name: Install Poetry + run: | + cd autogpt_platform/backend + HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry) + curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install Python dependencies + working-directory: autogpt_platform/backend + run: poetry install + + - name: Generate Prisma Client + working-directory: autogpt_platform/backend + run: poetry run prisma generate && poetry run gen-prisma-stub + + # Frontend Node.js/pnpm setup (so Claude can run linting/tests) + - name: Enable corepack + run: corepack enable + + - name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml + + - name: Install JavaScript dependencies + working-directory: autogpt_platform/frontend + run: pnpm install --frozen-lockfile + - name: Get CI failure details id: failure_details uses: actions/github-script@v8 diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index da37df6de7..274c6d2cab 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -77,27 +77,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ 
hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
 
       - name: Install JavaScript dependencies
         working-directory: autogpt_platform/frontend
         run: pnpm install --frozen-lockfile
diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml
index ee901fe5d4..8b8260af6b 100644
--- a/.github/workflows/claude.yml
+++ b/.github/workflows/claude.yml
@@ -93,27 +93,15 @@
         run: poetry run prisma generate && poetry run gen-prisma-stub
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
+      - name: Enable corepack
+        run: corepack enable
+
       - name: Set up Node.js
         uses: actions/setup-node@v6
         with:
           node-version: "22"
-
-      - name: Enable corepack
-        run: corepack enable
-
-      - name: Set pnpm store directory
-        run: |
-          pnpm config set store-dir ~/.pnpm-store
-          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
-
-      - name: Cache frontend dependencies
-        uses: actions/cache@v5
-        with:
-          path: ~/.pnpm-store
-          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
 
       - name: Install JavaScript dependencies
         working-directory: autogpt_platform/frontend
         run: pnpm install --frozen-lockfile
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 966243323c..ff535f8496 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -62,7 +62,7 @@
 
     # Initializes the CodeQL tools for scanning.
     - name: Initialize CodeQL
-      uses: github/codeql-action/init@v3
+      uses: github/codeql-action/init@v4
      with:
        languages: ${{ matrix.language }}
        build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@
           exit 1
 
    - name: Perform CodeQL Analysis
-      uses: github/codeql-action/analyze@v3
+      uses: github/codeql-action/analyze@v4
      with:
        category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/docs-claude-review.yml b/.github/workflows/docs-claude-review.yml
index ca2788b387..19d5dd667b 100644
--- a/.github/workflows/docs-claude-review.yml
+++ b/.github/workflows/docs-claude-review.yml
@@ -7,6 +7,10 @@
       - "docs/integrations/**"
       - "autogpt_platform/backend/backend/blocks/**"
 
+concurrency:
+  group: claude-docs-review-${{ github.event.pull_request.number }}
+  cancel-in-progress: true
+
 jobs:
   claude-review:
     # Only run for PRs from members/collaborators
@@ -91,5 +95,35 @@
             3. Read corresponding documentation files to verify accuracy
             4. Provide your feedback as a PR comment
 
+            ## IMPORTANT: Comment Marker
+            Start your PR comment with exactly this HTML comment marker on its own line:
+
+            <!-- claude-docs-review -->
+            This marker is used to identify and replace your comment on subsequent runs.
+
+            Be constructive and specific. If everything looks good, say so! If there are issues, explain what's wrong and suggest how to fix it.
+
+      - name: Delete old Claude review comments
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          # Get all comment IDs with our marker, sorted by creation date (oldest first)
+          COMMENT_IDS=$(gh api \
+            repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
+            --jq '[.[] | select(.body | contains("<!-- claude-docs-review -->"))] | sort_by(.created_at) | .[].id')
+
+          # Count comments
+          COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
+
+          if [ "$COMMENT_COUNT" -gt 1 ]; then
+            # Delete all but the last (newest) comment
+            echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
+              if [ -n "$COMMENT_ID" ]; then
+                echo "Deleting old review comment: $COMMENT_ID"
+                gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
+              fi
+            done
+          else
+            echo "No old review comments to clean up"
+          fi
diff --git a/.github/workflows/platform-frontend-ci.yml b/.github/workflows/platform-frontend-ci.yml
index 6410daae9f..4bf8a2b80c 100644
--- a/.github/workflows/platform-frontend-ci.yml
+++ b/.github/workflows/platform-frontend-ci.yml
@@ -26,7 +26,6 @@
   setup:
     runs-on: ubuntu-latest
     outputs:
-      cache-key: ${{ steps.cache-key.outputs.key }}
       components-changed: ${{ steps.filter.outputs.components }}
 
     steps:
@@ -41,28 +40,17 @@
             components:
               - 'autogpt_platform/frontend/src/components/**'
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-
       - name: Enable corepack
         run: corepack enable
 
-      - name: Generate cache key
-        id: cache-key
-        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
-
-      - name: Cache dependencies
-        uses: actions/cache@v5
+      - name: Set up Node
+        uses: actions/setup-node@v6
         with:
-          path: ~/.pnpm-store
-          key: ${{ steps.cache-key.outputs.key }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          node-version: "22.18.0"
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
 
-      - name: Install dependencies
+      - name: Install dependencies to populate cache
         run: pnpm install --frozen-lockfile
 
   lint:
@@ -73,22 +61,15 @@
       - name: Checkout repository
         uses: actions/checkout@v6
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-
       - name: Enable corepack
         run: corepack enable
 
-      - name: Restore dependencies cache
-        uses: actions/cache@v5
+      - name: Set up Node
+        uses: actions/setup-node@v6
         with:
-          path: ~/.pnpm-store
-          key: ${{ needs.setup.outputs.cache-key }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          node-version: "22.18.0"
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
 
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
@@ -111,22 +92,15 @@
         with:
           fetch-depth: 0
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-
       - name: Enable corepack
         run: corepack enable
 
-      - name: Restore dependencies cache
-        uses: actions/cache@v5
+      - name: Set up Node
+        uses: actions/setup-node@v6
        with:
-          path: ~/.pnpm-store
-          key: ${{ needs.setup.outputs.cache-key }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          node-version: "22.18.0"
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
 
       - name: Install dependencies
         run: 
pnpm install --frozen-lockfile @@ -141,10 +115,8 @@ jobs: exitOnceUploaded: true e2e_test: + name: end-to-end tests runs-on: big-boi - needs: setup - strategy: - fail-fast: false steps: - name: Checkout repository @@ -152,19 +124,11 @@ jobs: with: submodules: recursive - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - - name: Enable corepack - run: corepack enable - - - name: Copy default supabase .env + - name: Set up Platform - Copy default supabase .env run: | cp ../.env.default ../.env - - name: Copy backend .env and set OpenAI API key + - name: Set up Platform - Copy backend .env and set OpenAI API key run: | cp ../backend/.env.default ../backend/.env echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env @@ -172,77 +136,125 @@ jobs: # Used by E2E test data script to generate embeddings for approved store agents OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - - name: Set up Docker Buildx + - name: Set up Platform - Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: network=host - - name: Cache Docker layers + - name: Set up Platform - Expose GHA cache to docker buildx CLI + uses: crazy-max/ghaction-github-runtime@v3 + + - name: Set up Platform - Build Docker images (with cache) + working-directory: autogpt_platform + run: | + pip install pyyaml + + # Resolve extends and generate a flat compose file that bake can understand + docker compose -f docker-compose.yml config > docker-compose.resolved.yml + + # Add cache configuration to the resolved compose file + python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ + --source docker-compose.resolved.yml \ + --cache-from "type=gha" \ + --cache-to "type=gha,mode=max" \ + --backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \ + --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \ + --git-ref "${{ github.ref }}" + + # Build with bake using the resolved compose file (now includes cache config) + docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load + env: + NEXT_PUBLIC_PW_TEST: true + + - name: Set up tests - Cache E2E test data + id: e2e-data-cache uses: actions/cache@v5 with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }} - restore-keys: | - ${{ runner.os }}-buildx-frontend-test- + path: /tmp/e2e_test_data.sql + key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }} - - name: Run docker compose + - name: Set up Platform - Start Supabase DB + Auth run: | - NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d + docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build + echo "Waiting for database to be ready..." + timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' + echo "Waiting for auth service to be ready..." 
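+          # Assumption: the auth schema (auth.users) is created by the Supabase auth service on first boot, so its presence signals readiness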
+ timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..." + + - name: Set up Platform - Run migrations + run: | + echo "Running migrations..." + docker compose -f ../docker-compose.resolved.yml run --rm migrate + echo "✅ Migrations completed" env: - DOCKER_BUILDKIT: 1 - BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache - BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max + NEXT_PUBLIC_PW_TEST: true - - name: Move cache + - name: Set up tests - Load cached E2E test data + if: steps.e2e-data-cache.outputs.cache-hit == 'true' run: | - rm -rf /tmp/.buildx-cache - if [ -d "/tmp/.buildx-cache-new" ]; then - mv /tmp/.buildx-cache-new /tmp/.buildx-cache - fi + echo "✅ Found cached E2E test data, restoring..." + { + echo "SET session_replication_role = 'replica';" + cat /tmp/e2e_test_data.sql + echo "SET session_replication_role = 'origin';" + } | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b + # Refresh materialized views after restore + docker compose -f ../docker-compose.resolved.yml exec -T db \ + psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true - - name: Wait for services to be ready + echo "✅ E2E test data restored from cache" + + - name: Set up Platform - Start (all other services) run: | + docker compose -f ../docker-compose.resolved.yml up -d --no-build echo "Waiting for rest_server to be ready..." timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..." - echo "Waiting for database to be ready..." - timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..." + env: + NEXT_PUBLIC_PW_TEST: true - - name: Create E2E test data + - name: Set up tests - Create E2E test data + if: steps.e2e-data-cache.outputs.cache-hit != 'true' run: | echo "Creating E2E test data..." - # First try to run the script from inside the container - if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then - echo "✅ Found e2e_test_data.py in container, running it..." - docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || { - echo "❌ E2E test data creation failed!" - docker compose -f ../docker-compose.yml logs --tail=50 rest_server - exit 1 - } - else - echo "⚠️ e2e_test_data.py not found in container, copying and running..." - # Copy the script into the container and run it - docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || { - echo "❌ Failed to copy script to container" - exit 1 - } - docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || { - echo "❌ E2E test data creation failed!" 
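# Surface recent backend logs to make the data-creation failure easier to diagnose
+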
-              docker compose -f ../docker-compose.yml logs --tail=50 rest_server
-              exit 1
-            }
-          fi
+          docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
+          docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
+            echo "❌ E2E test data creation failed!"
+            docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
+            exit 1
+          }

-      - name: Restore dependencies cache
-        uses: actions/cache@v5
+          # Dump auth.users + platform schema for cache (two separate dumps)
+          echo "Dumping database for cache..."
+          {
+            docker compose -f ../docker-compose.resolved.yml exec -T db \
+              pg_dump -U postgres --data-only --column-inserts \
+              --table='auth.users' postgres
+            docker compose -f ../docker-compose.resolved.yml exec -T db \
+              pg_dump -U postgres --data-only --column-inserts \
+              --schema=platform \
+              --exclude-table='platform._prisma_migrations' \
+              --exclude-table='platform.apscheduler_jobs' \
+              --exclude-table='platform.apscheduler_jobs_batched_notifications' \
+              postgres
+          } > /tmp/e2e_test_data.sql
+
+          echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
+
+      - name: Set up tests - Enable corepack
+        run: corepack enable
+
+      - name: Set up tests - Set up Node
+        uses: actions/setup-node@v6
         with:
-          path: ~/.pnpm-store
-          key: ${{ needs.setup.outputs.cache-key }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          node-version: "22.18.0"
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

-      - name: Install dependencies
+      - name: Set up tests - Install dependencies
         run: pnpm install --frozen-lockfile

-      - name: Install Browser 'chromium'
+      - name: Set up tests - Install browser 'chromium'
         run: pnpm playwright install --with-deps chromium

       - name: Run Playwright tests
@@ -269,7 +281,7 @@ jobs:

       - name: Print Final Docker Compose logs
         if: always()
-        run: docker compose -f ../docker-compose.yml logs
+        run: docker compose -f ../docker-compose.resolved.yml logs

   integration_test:
     runs-on: ubuntu-latest
@@ -281,22 +293,15 @@ jobs:
       with:
         submodules: recursive

-      - name: Set up Node.js
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-
       - name: Enable corepack
         run: corepack enable

-      - name: Restore dependencies cache
-        uses: actions/cache@v5
+      - name: Set up Node
+        uses: actions/setup-node@v6
         with:
-          path: ~/.pnpm-store
-          key: ${{ needs.setup.outputs.cache-key }}
-          restore-keys: |
-            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
-            ${{ runner.os }}-pnpm-
+          node-version: "22.18.0"
+          cache: "pnpm"
+          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

       - name: Install dependencies
         run: pnpm install --frozen-lockfile
diff --git a/.github/workflows/pr-overlap-check.yml b/.github/workflows/pr-overlap-check.yml
new file mode 100644
index 0000000000..c53f56321b
--- /dev/null
+++ b/.github/workflows/pr-overlap-check.yml
@@ -0,0 +1,39 @@
+name: PR Overlap Detection
+
+on:
+  pull_request:
+    types: [opened, synchronize, reopened]
+    branches:
+      - dev
+      - master
+
+permissions:
+  contents: read
+  pull-requests: write
+
+jobs:
+  check-overlaps:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0 # Need full history for merge testing
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Configure git
+        run: |
+          git config user.email "github-actions[bot]@users.noreply.github.com"
+          git config user.name "github-actions[bot]"
+
+      - name: Run overlap detection
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        # Always succeed - this check informs contributors, it shouldn't block merging
+        continue-on-error: true
+        run: |
+          python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
diff --git a/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py b/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
new file mode 100644
index 0000000000..33693fc739
--- /dev/null
+++ b/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+"""
+Add cache configuration to a resolved docker-compose file for all services
+that have a build key, and ensure image names match what docker compose expects.
+"""
+
+import argparse
+
+import yaml
+
+
+DEFAULT_BRANCH = "dev"
+CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Add cache config to a resolved compose file"
+    )
+    parser.add_argument(
+        "--source",
+        required=True,
+        help="Source compose file to read (should be output of `docker compose config`)",
+    )
+    parser.add_argument(
+        "--cache-from",
+        default="type=gha",
+        help="Cache source configuration",
+    )
+    parser.add_argument(
+        "--cache-to",
+        default="type=gha,mode=max",
+        help="Cache destination configuration",
+    )
+    for component in CACHE_BUILDS_FOR_COMPONENTS:
+        parser.add_argument(
+            f"--{component}-hash",
+            default="",
+            help=f"Hash for {component} cache scope (e.g., from hashFiles())",
+        )
+    parser.add_argument(
+        "--git-ref",
+        default="",
+        help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
+    )
+    args = parser.parse_args()
+
+    # Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
+    git_ref_scope = ""
+    if args.git_ref:
+        git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
+
+    with open(args.source, "r") as f:
+        compose = yaml.safe_load(f)
+
+    # Get project name from compose file or default
+    project_name = compose.get("name", "autogpt_platform")
+
+    def get_image_name(dockerfile: str, target: str) -> str:
+        """Generate image name based on Dockerfile folder and build target."""
+        dockerfile_parts = dockerfile.replace("\\", "/").split("/")
+        if len(dockerfile_parts) >= 2:
+            folder_name = dockerfile_parts[-2]  # e.g., "backend" or "frontend"
+        else:
+            folder_name = "app"
+        return f"{project_name}-{folder_name}:{target}"
+
+    def get_build_key(dockerfile: str, target: str) -> str:
+        """Generate a unique key for a Dockerfile+target combination."""
+        return f"{dockerfile}:{target}"
+
+    def get_component(dockerfile: str) -> str | None:
+        """Get component name (frontend/backend) from dockerfile path."""
+        for component in CACHE_BUILDS_FOR_COMPONENTS:
+            if component in dockerfile:
+                return component
+        return None
+
+    # First pass: collect all services with build configs and identify duplicates
+    # Track which (dockerfile, target) combinations we've seen
+    build_key_to_first_service: dict[str, str] = {}
+    services_to_build: list[str] = []
+    services_to_dedupe: list[str] = []
+
+    for service_name, service_config in compose.get("services", {}).items():
+        if "build" not in service_config:
+            continue
+
+        build_config = service_config["build"]
+        dockerfile = build_config.get("dockerfile", "Dockerfile")
+        target = build_config.get("target", "default")
+        build_key = get_build_key(dockerfile, target)
+
+        if build_key not in build_key_to_first_service:
+            # First service with this build config - it will do the actual build
+            build_key_to_first_service[build_key] = service_name
+            services_to_build.append(service_name)
+        else:
+            # Duplicate - will just use the image from the first service
+            services_to_dedupe.append(service_name)
+
+    # Second pass: configure builds and deduplicate
+    modified_services = []
+    for service_name, service_config in compose.get("services", {}).items():
+        if "build" not in service_config:
+            continue
+
+        build_config = service_config["build"]
+        dockerfile = build_config.get("dockerfile", "Dockerfile")
+        target = build_config.get("target", "default")  # same default as the first pass
+        image_name = get_image_name(dockerfile, target)
+
+        # Set image name for all services (needed for both builders and deduped)
+        service_config["image"] = image_name
+
+        if service_name in services_to_dedupe:
+            # Remove build config - this service will use the pre-built image
+            del service_config["build"]
+            continue
+
+        # This service will do the actual build - add cache config
+        cache_from_list = []
+        cache_to_list = []
+
+        component = get_component(dockerfile)
+        if not component:
+            # Skip services that don't clearly match frontend/backend
+            continue
+
+        # Get the hash for this component
+        component_hash = getattr(args, f"{component}_hash")
+
+        # Scope format: platform-{component}-{target}-{hash|ref}
+        # Example: platform-backend-server-abc123
+
+        if "type=gha" in args.cache_from:
+            # 1. Primary: exact hash match (most specific)
+            if component_hash:
+                hash_scope = f"platform-{component}-{target}-{component_hash}"
+                cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
+
+            # 2. Fallback: branch-based cache
+            if git_ref_scope:
+                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
+                cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
+
+            # 3. Fallback: dev branch cache (for PRs/feature branches)
+            if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
+                default_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
+                cache_from_list.append(f"{args.cache_from},scope={default_scope}")
+
+        if "type=gha" in args.cache_to:
+            # Write to both hash-based and branch-based scopes
+            if component_hash:
+                hash_scope = f"platform-{component}-{target}-{component_hash}"
+                cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
+
+            if git_ref_scope:
+                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
+                cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
+
+        # Ensure we have at least one cache source/target
+        if not cache_from_list:
+            cache_from_list.append(args.cache_from)
+        if not cache_to_list:
+            cache_to_list.append(args.cache_to)
+
+        build_config["cache_from"] = cache_from_list
+        build_config["cache_to"] = cache_to_list
+        modified_services.append(service_name)
+
+    # Write back to the same file
+    with open(args.source, "w") as f:
+        yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
+
+    print(f"Added cache config to {len(modified_services)} services in {args.source}:")
+    for svc in modified_services:
+        svc_config = compose["services"][svc]
+        build_cfg = svc_config.get("build", {})
+        cache_from_list = build_cfg.get("cache_from", ["none"])
+        cache_to_list = build_cfg.get("cache_to", ["none"])
+        print(f"  - {svc}")
+        print(f"    image: {svc_config.get('image', 'N/A')}")
+        print(f"    cache_from: {cache_from_list}")
+        print(f"    cache_to: {cache_to_list}")
+    if services_to_dedupe:
+        print(
+            f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
+        )
+        for svc in services_to_dedupe:
+            print(f"  - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/autogpt_platform/CLAUDE.md b/autogpt_platform/CLAUDE.md
index 62adbdaefa..021b7c27e4 100644
--- a/autogpt_platform/CLAUDE.md
+++ b/autogpt_platform/CLAUDE.md
@@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing:
 - Backend/Frontend services use YAML anchors for consistent configuration
 - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

+### Branching Strategy
+
+- **`dev`** is the main development branch. All PRs should target `dev`.
+- **`master`** is the production branch, used only for production releases.
+
 ### Creating Pull Requests

 - Create the PR against the `dev` branch of the repository.
diff --git a/autogpt_platform/autogpt_libs/poetry.lock b/autogpt_platform/autogpt_libs/poetry.lock
index 0a421dda31..e1d599360e 100644
--- a/autogpt_platform/autogpt_libs/poetry.lock
+++ b/autogpt_platform/autogpt_libs/poetry.lock
@@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""]

 [[package]]
 name = "cryptography"
-version = "46.0.4"
+version = "46.0.5"
 description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
 optional = false
 python-versions = "!=3.9.0,!=3.9.1,>=3.8"
 groups = ["main"]
 files = [
-    {file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
-    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
-    {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
-    {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
-    {file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
-    {file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
-    {file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
-    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
-    {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
-    {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
-    {file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
-    {file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
-    {file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
-    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
-    {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
-    {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
-    {file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
-    {file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
-    {file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
-    {file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
+    {file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
+    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
+    {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
+    {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
+    {file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
+    {file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
+    {file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
+    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
+    {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
+    {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
+    {file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
+    {file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
+    {file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
+    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
+    {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
+    {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
+    {file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
+    {file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
+    {file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
+    {file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
 ]

 [package.dependencies]
@@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
 pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
 sdist = ["build (>=1.0.0)"]
 ssh = ["bcrypt (>=3.1.5)"]
-test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
+test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
 test-randomorder = ["pytest-randomly"]

 [[package]]
@@ -570,24 +570,25 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]

 [[package]]
 name = "fastapi"
-version = "0.128.0"
+version = "0.128.7"
 description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"},
-    {file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"},
+    {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
+    {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
 ]

 [package.dependencies]
 annotated-doc = ">=0.0.2"
 pydantic = ">=2.7.0"
-starlette = ">=0.40.0,<0.51.0"
+starlette = ">=0.40.0,<1.0.0"
 typing-extensions = ">=4.8.0"
+typing-inspection = ">=0.4.2"

 [package.extras]
-all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
+all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
 standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
 standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
@@ -1062,14 +1063,14 @@ urllib3 = ">=1.26.0,<3"

 [[package]]
 name = "launchdarkly-server-sdk"
-version = "9.14.1"
+version = "9.15.0"
 description = "LaunchDarkly SDK for Python"
 optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
-    {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
+    {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
+    {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
 ]

 [package.dependencies]
@@ -1478,14 +1479,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]

 [[package]]
 name = "postgrest"
-version = "2.27.2"
+version = "2.28.0"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
-    {file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
+    {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
+    {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
 ]

 [package.dependencies]
@@ -2248,14 +2249,14 @@ cli = ["click (>=5.0)"]

 [[package]]
 name = "realtime"
-version = "2.27.2"
+version = "2.28.0"
 description = ""
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
-    {file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
+    {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
+    {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
 ]

 [package.dependencies]
@@ -2436,14 +2437,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart

 [[package]]
 name = "storage3"
-version = "2.27.2"
+version = "2.28.0"
 description = "Supabase Storage client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
-    {file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
+    {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
+    {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
 ]

 [package.dependencies]
@@ -2487,35 +2488,35 @@ python-dateutil = ">=2.6.0"

 [[package]]
 name = "supabase"
-version = "2.27.2"
+version = "2.28.0"
 description = "Supabase client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
-    {file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
+    {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
+    {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
 ]

 [package.dependencies]
 httpx = ">=0.26,<0.29"
-postgrest = "2.27.2"
-realtime = "2.27.2"
-storage3 = "2.27.2"
-supabase-auth = "2.27.2"
-supabase-functions = "2.27.2"
+postgrest = "2.28.0"
+realtime = "2.28.0"
+storage3 = "2.28.0"
+supabase-auth = "2.28.0"
+supabase-functions = "2.28.0"
 yarl = ">=1.22.0"

 [[package]]
 name = "supabase-auth"
-version = "2.27.2"
+version = "2.28.0"
 description = "Python Client Library for Supabase Auth"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
-    {file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
+    {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
+    {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
 ]

 [package.dependencies]
@@ -2525,14 +2526,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}

 [[package]]
 name = "supabase-functions"
-version = "2.27.2"
+version = "2.28.0"
 description = "Library for Supabase Functions"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
-    {file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
+    {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
+    {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
 ]

 [package.dependencies]
@@ -2911,4 +2912,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
+content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
diff --git a/autogpt_platform/autogpt_libs/pyproject.toml b/autogpt_platform/autogpt_libs/pyproject.toml
index 8deb4d2169..2cfa742922 100644
--- a/autogpt_platform/autogpt_libs/pyproject.toml
+++ b/autogpt_platform/autogpt_libs/pyproject.toml
@@ -11,14 +11,14 @@ python = ">=3.10,<4.0"
 colorama = "^0.4.6"
 cryptography = "^46.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.0"
+fastapi = "^0.128.7"
 google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.14.1"
+launchdarkly-server-sdk = "^9.15.0"
 pydantic = "^2.12.5"
 pydantic-settings = "^2.12.0"
 pyjwt = { version = "^2.11.0", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.27.2"
+supabase = "^2.28.0"
 uvicorn = "^0.40.0"

 [tool.poetry.group.dev.dependencies]
diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default
index fa52ba812a..2711bd2df9 100644
--- a/autogpt_platform/backend/.env.default
+++ b/autogpt_platform/backend/.env.default
@@ -104,6 +104,12 @@ TWITTER_CLIENT_SECRET=
 # Make a new workspace for your OAuth APP -- trust me
 # https://linear.app/settings/api/applications/new
 # Callback URL: http://localhost:3000/auth/integrations/oauth_callback
+LINEAR_API_KEY=
+# Linear project and team IDs for the feature request tracker.
+# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
+# and in team settings. Used by the chat copilot to file and search feature requests.
+LINEAR_FEATURE_REQUEST_PROJECT_ID=
+LINEAR_FEATURE_REQUEST_TEAM_ID=
 LINEAR_CLIENT_ID=
 LINEAR_CLIENT_SECRET=
diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile
index 9bd455e490..05a8d4858b 100644
--- a/autogpt_platform/backend/Dockerfile
+++ b/autogpt_platform/backend/Dockerfile
@@ -1,3 +1,5 @@
+# ============================ DEPENDENCY BUILDER ============================ #
+
 FROM debian:13-slim AS builder

 # Set environment variables
@@ -51,7 +53,9 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub

-FROM debian:13-slim AS server_dependencies
+# ============================== BACKEND SERVER ============================== #
+
+FROM debian:13-slim AS server

 WORKDIR /app

@@ -62,16 +66,21 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH

-# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
-RUN apt-get update && apt-get install -y \
+# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
+# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
+# for the bash_exec MCP tool.
+# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
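As an aside on the bubblewrap comment above: this PR doesn't show the bash_exec wiring itself, but a "whitelist-only FS + no network" invocation typically looks like the sketch below. The `bwrap` flags are standard bubblewrap options; the `run_sandboxed` helper, mount layout, and timeout are illustrative assumptions, not project code.

```python
# Hypothetical helper (not part of this PR): run a command under bubblewrap
# with an explicit read-only whitelist and no network namespace.
import subprocess

def run_sandboxed(command: str, workdir: str) -> subprocess.CompletedProcess[str]:
    """Run `command` with a whitelist-only filesystem and no network access."""
    return subprocess.run(
        [
            "bwrap",
            "--ro-bind", "/usr", "/usr",   # system dirs mounted read-only
            "--symlink", "usr/bin", "/bin",
            "--symlink", "usr/lib", "/lib",
            "--bind", workdir, "/work",    # the only writable path
            "--chdir", "/work",
            "--unshare-net",               # no network access
            "--unshare-pid",               # isolated PID namespace
            "--die-with-parent",
            "/bin/sh", "-c", command,
        ],
        capture_output=True,
        text=True,
        timeout=60,
    )
```

Because nothing outside the whitelist is bound into the sandbox, the executed command cannot read secrets such as the backend `.env`, which is the point of shipping bubblewrap in this image.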
+RUN apt-get update && apt-get install -y --no-install-recommends \
     python3.13 \
     python3-pip \
     ffmpeg \
     imagemagick \
+    jq \
+    ripgrep \
+    tree \
+    bubblewrap \
     && rm -rf /var/lib/apt/lists/*

-# Copy only necessary files from builder
-COPY --from=builder /app /app
 COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
 COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
 # Copy Node.js installation for Prisma
@@ -81,30 +90,54 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
 COPY --from=builder /usr/bin/npx /usr/bin/npx
 COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

-ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
-
-RUN mkdir -p /app/autogpt_platform/autogpt_libs
-RUN mkdir -p /app/autogpt_platform/backend
-
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
-
 WORKDIR /app/autogpt_platform/backend

-FROM server_dependencies AS migrate
+# Copy only the .venv from builder (not the entire /app directory)
+# The .venv includes the generated Prisma client
+COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
+ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

-# Migration stage only needs schema and migrations - much lighter than full backend
-COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
-COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
-COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
+# Copy dependency files + autogpt_libs (path dependency)
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./

-FROM server_dependencies AS server
-
-COPY autogpt_platform/backend /app/autogpt_platform/backend
+# Copy backend code + docs (for Copilot docs search)
+COPY autogpt_platform/backend ./
 COPY docs /app/docs

 RUN poetry install --no-ansi --only-root

 ENV PORT=8000

 CMD ["poetry", "run", "rest"]
+
+# =============================== DB MIGRATOR =============================== #
+
+# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
+FROM debian:13-slim AS migrate
+
+WORKDIR /app/autogpt_platform/backend
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.13 \
+    python3-pip \
+    ca-certificates \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy Node.js from builder (needed for Prisma CLI)
+COPY --from=builder /usr/bin/node /usr/bin/node
+COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+
+# Copy Prisma binaries
+COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
+
+# Install prisma-client-py directly (much smaller than copying full venv);
+# the requirement is quoted so the shell doesn't parse `>=` as a redirection
+RUN pip3 install "prisma>=0.15.0" --break-system-packages
+
+COPY autogpt_platform/backend/schema.prisma ./
+COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
+COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
+COPY autogpt_platform/backend/migrations ./migrations
diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py
index 808692f97f..04bbe8e60d 100644
--- a/autogpt_platform/backend/backend/api/features/chat/config.py
+++ b/autogpt_platform/backend/backend/api/features/chat/config.py
@@ -27,12 +27,11 @@ class ChatConfig(BaseSettings):
     session_ttl: int = Field(default=43200, description="Session TTL in seconds")

     # Streaming Configuration
-    max_context_messages: int = Field(
-        default=50, ge=1, le=200, description="Maximum context messages"
-    )
-    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
-    max_retries: int = Field(default=3, description="Maximum number of retries")
+    max_retries: int = Field(
+        default=3,
+        description="Max retries for fallback path (SDK handles retries internally)",
+    )
     max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
         default=30, description="Maximum number of agent schedules"
@@ -93,6 +92,31 @@ class ChatConfig(BaseSettings):
         description="Name of the prompt in Langfuse to fetch",
     )

+    # Claude Agent SDK Configuration
+    use_claude_agent_sdk: bool = Field(
+        default=True,
+        description="Use Claude Agent SDK for chat completions",
+    )
+    claude_agent_model: str | None = Field(
+        default=None,
+        description="Model for the Claude Agent SDK path. If None, derives from "
+        "the `model` field by stripping the OpenRouter provider prefix.",
+    )
+    claude_agent_max_buffer_size: int = Field(
+        default=10 * 1024 * 1024,  # 10MB (default SDK is 1MB)
+        description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
+        "Increase if tool outputs exceed the limit.",
+    )
+    claude_agent_max_subtasks: int = Field(
+        default=10,
+        description="Max number of sub-agent Tasks the SDK can spawn per session.",
+    )
+    claude_agent_use_resume: bool = Field(
+        default=True,
+        description="Use --resume for multi-turn conversations instead of "
+        "history compression. Falls back to compression when unavailable.",
+    )
+
     # Extended thinking configuration for Claude models
     thinking_enabled: bool = Field(
         default=True,
@@ -138,6 +162,17 @@ class ChatConfig(BaseSettings):
         v = os.getenv("CHAT_INTERNAL_API_KEY")
         return v

+    @field_validator("use_claude_agent_sdk", mode="before")
+    @classmethod
+    def get_use_claude_agent_sdk(cls, v):
+        """Get use_claude_agent_sdk from environment if not provided."""
+        # Check environment variable - default to True if not set
+        env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
+        if env_val:
+            return env_val in ("true", "1", "yes", "on")
+        # Default to True (SDK enabled by default)
+        return True if v is None else v
+
 # Prompt paths for different contexts
 PROMPT_PATHS: dict[str, str] = {
     "default": "prompts/chat_system.md",
diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py
index 35418f174f..30ac27aece 100644
--- a/autogpt_platform/backend/backend/api/features/chat/model.py
+++ b/autogpt_platform/backend/backend/api/features/chat/model.py
@@ -334,9 +334,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
     try:
         session = ChatSession.model_validate_json(raw_session)
         logger.info(
-            f"Loading session {session_id} from cache: "
-            f"message_count={len(session.messages)}, "
-            f"roles={[m.role for m in session.messages]}"
+            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
+            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
         )
         return session
     except Exception as e:
@@ -378,11 +377,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
         return None

     messages = prisma_session.Messages
-    logger.info(
-        f"Loading session {session_id} from DB: "
-        f"has_messages={messages is not None}, "
-        f"message_count={len(messages) if messages else 0}, "
-        f"roles={[m.role for m in messages] if messages else []}"
+    logger.debug(
+        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
+        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
     )

     return ChatSession.from_db(prisma_session, messages)
@@ -433,10 +430,9 @@ async def _save_session_to_db(
                 "function_call": msg.function_call,
             }
         )
-    logger.info(
-        f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
-        f"roles={[m['role'] for m in messages_data]}, "
-        f"start_sequence={existing_message_count}"
+    logger.debug(
+        f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
+        f"roles={[m['role'] for m in messages_data]}"
     )
     await chat_db.add_chat_messages_batch(
         session_id=session.session_id,
@@ -476,7 +472,7 @@ async def get_chat_session(
             logger.warning(f"Unexpected cache error for session {session_id}: {e}")

     # Fall back to database
-    logger.info(f"Session {session_id} not in cache, checking database")
+    logger.debug(f"Session {session_id} not in cache, checking database")
     session = await _get_session_from_db(session_id)

     if session is None:
@@ -493,7 +489,6 @@ async def get_chat_session(
     # Cache the session from DB
     try:
         await _cache_session(session)
-        logger.info(f"Cached session {session_id} from database")
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")

@@ -558,6 +553,40 @@ async def upsert_chat_session(
     return session

+
+async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
+    """Atomically append a message to a session and persist it.
+
+    Acquires the session lock, re-fetches the latest session state,
+    appends the message, and saves — preventing message loss when
+    concurrent requests modify the same session.
+    """
+    lock = await _get_session_lock(session_id)
+
+    async with lock:
+        session = await get_chat_session(session_id)
+        if session is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        session.messages.append(message)
+        existing_message_count = await chat_db.get_chat_session_message_count(
+            session_id
+        )
+
+        try:
+            await _save_session_to_db(session, existing_message_count)
+        except Exception as e:
+            raise DatabaseError(
+                f"Failed to persist message to session {session_id}"
+            ) from e
+
+        try:
+            await _cache_session(session)
+        except Exception as e:
+            logger.warning(f"Cache write failed for session {session_id}: {e}")
+
+        return session
+
+
 async def create_chat_session(user_id: str) -> ChatSession:
     """Create a new chat session and persist it.

@@ -664,13 +693,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
             logger.warning(f"Session {session_id} not found for title update")
             return False

-        # Invalidate cache so next fetch gets updated title
+        # Update title in cache if it exists (instead of invalidating).
+        # This prevents race conditions where cache invalidation causes
+        # the frontend to see stale DB data while streaming is still in progress.
         try:
-            redis_key = _get_session_cache_key(session_id)
-            async_redis = await get_redis_async()
-            await async_redis.delete(redis_key)
+            cached = await _get_session_from_cache(session_id)
+            if cached:
+                cached.title = title
+                await _cache_session(cached)
         except Exception as e:
-            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
+            # Not critical - title will be correct on next full cache refresh
+            logger.warning(
+                f"Failed to update title in cache for session {session_id}: {e}"
+            )

         return True
     except Exception as e:
diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index c6f37569b7..aa565ca891 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -1,5 +1,6 @@
 """Chat API routes for chat session management and streaming via SSE."""

+import asyncio
 import logging
 import uuid as uuid_module
 from collections.abc import AsyncGenerator
@@ -11,19 +12,29 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel

 from backend.util.exceptions import NotFoundError
+from backend.util.feature_flag import Flag, is_feature_enabled

 from . import service as chat_service
 from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
-from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat
+from .model import (
+    ChatMessage,
+    ChatSession,
+    append_and_save_message,
+    create_chat_session,
+    get_chat_session,
+    get_user_sessions,
+)
+from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
+from .sdk import service as sdk_service
 from .tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
     AgentPreviewResponse,
     AgentSavedResponse,
     AgentsFoundResponse,
+    BlockDetailsResponse,
     BlockListResponse,
     BlockOutputResponse,
     ClarificationNeededResponse,
@@ -40,6 +51,7 @@ from .tools.models import (
     SetupRequirementsResponse,
     UnderstandingUpdatedResponse,
 )
+from .tracking import track_user_message

 config = ChatConfig()

@@ -231,6 +243,10 @@ async def get_session(
     active_task, last_message_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )
+    logger.info(
+        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
+        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
+    )
     if active_task:
         # Filter out the in-progress assistant message from the session response.
         # The client will receive the complete assistant response through the SSE
@@ -300,10 +316,9 @@ async def stream_chat_post(
         f"user={user_id}, message_len={len(request.message)}",
         extra={"json_fields": log_meta},
     )
-
     session = await _validate_and_get_session(session_id, user_id)
     logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
+        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -312,6 +327,25 @@ async def stream_chat_post(
         },
     )

+    # Atomically append user message to session BEFORE creating task to avoid
+    # race condition where GET_SESSION sees task as "running" but message isn't
+    # saved yet. append_and_save_message re-fetches inside a lock to prevent
+    # message loss from concurrent requests.
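The lost-update race that this comment (and `append_and_save_message` above) guards against is easy to reproduce in isolation. A self-contained sketch — illustrative only, not project code:

```python
import asyncio

store: dict[str, list[str]] = {"s1": []}  # stand-in for the session cache/DB
lock = asyncio.Lock()

async def append_unsafe(msg: str) -> None:
    messages = list(store["s1"])      # read current state
    await asyncio.sleep(0)            # yield — a concurrent writer sneaks in here
    messages.append(msg)
    store["s1"] = messages            # write back, clobbering the other write

async def append_safe(msg: str) -> None:
    async with lock:                  # re-fetch *inside* the critical section
        messages = list(store["s1"])
        messages.append(msg)
        store["s1"] = messages

async def main() -> None:
    await asyncio.gather(append_unsafe("a"), append_unsafe("b"))
    print(store["s1"])                # ['a'] or ['b'] — one message is lost
    store["s1"] = []
    await asyncio.gather(append_safe("a"), append_safe("b"))
    print(store["s1"])                # always both messages

asyncio.run(main())
```

`append_and_save_message` applies the same idea with a per-session lock, so the second writer always sees the first writer's message before appending its own.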
+    if request.message:
+        message = ChatMessage(
+            role="user" if request.is_user_message else "assistant",
+            content=request.message,
+        )
+        if request.is_user_message:
+            track_user_message(
+                user_id=user_id,
+                session_id=session_id,
+                message_length=len(request.message),
+            )
+        logger.info(f"[STREAM] Saving user message to session {session_id}")
+        session = await append_and_save_message(session_id, message)
+        logger.info(f"[STREAM] User message saved for session {session_id}")
+
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
@@ -327,7 +361,7 @@ async def stream_chat_post(
         operation_id=operation_id,
     )
     logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -348,15 +382,47 @@ async def stream_chat_post(
         first_chunk_time, ttfc = None, None
         chunk_count = 0
         try:
-            async for chunk in chat_service.stream_chat_completion(
+            # Emit a start event with task_id for reconnection
+            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
+            await stream_registry.publish_chunk(task_id, start_chunk)
+            logger.info(
+                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
+                        * 1000,
+                    }
+                },
+            )
+
+            # Choose service based on LaunchDarkly flag (falls back to config default)
+            use_sdk = await is_feature_enabled(
+                Flag.COPILOT_SDK,
+                user_id or "anonymous",
+                default=config.use_claude_agent_sdk,
+            )
+            stream_fn = (
+                sdk_service.stream_chat_completion_sdk
+                if use_sdk
+                else chat_service.stream_chat_completion
+            )
+            logger.info(
+                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
+                extra={"json_fields": log_meta},
+            )
+            # Pass message=None since we already added it to the session above
+            async for chunk in stream_fn(
                 session_id,
-                request.message,
+                None,  # Message already in session
                 is_user_message=request.is_user_message,
                 user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
+                session=session,  # Pass session with message already added
                 context=request.context,
-                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
             ):
+                # Skip duplicate StreamStart — we already published one above
+                if isinstance(chunk, StreamStart):
+                    continue
                 chunk_count += 1
                 if first_chunk_time is None:
                     first_chunk_time = time_module.perf_counter()
@@ -377,7 +443,7 @@ async def stream_chat_post(
             gen_end_time = time_module.perf_counter()
             total_time = (gen_end_time - gen_start_time) * 1000
             logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
+                f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
                 f"task={task_id}, session={session_id}, "
                 f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
                 extra={
@@ -404,6 +470,17 @@ async def stream_chat_post(
                     }
                 },
             )
+            # Publish a StreamError so the frontend can display an error message
+            try:
+                await stream_registry.publish_chunk(
+                    task_id,
+                    StreamError(
+                        errorText="An error occurred. Please try again.",
+                        code="stream_error",
+                    ),
+                )
+            except Exception:
+                pass  # Best-effort; mark_task_completed will publish StreamFinish
             await stream_registry.mark_task_completed(task_id, "failed")

     # Start the AI generation in a background task
@@ -506,8 +583,14 @@ async def stream_chat_post(
                     "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
                 },
             )
+            # Surface error to frontend so it doesn't appear stuck
+            yield StreamError(
+                errorText="An error occurred. Please try again.",
+                code="stream_error",
+            ).to_sse()
+            yield StreamFinish().to_sse()
         finally:
-            # Unsubscribe when client disconnects or stream ends to prevent resource leak
+            # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
                     await stream_registry.unsubscribe_from_task(
@@ -751,8 +834,6 @@ async def stream_task(
     )

     async def event_generator() -> AsyncGenerator[str, None]:
-        import asyncio
-
         heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
@@ -971,6 +1052,7 @@ ToolResponseUnion = (
     | AgentSavedResponse
     | ClarificationNeededResponse
     | BlockListResponse
+    | BlockDetailsResponse
    | BlockOutputResponse
    | DocSearchResultsResponse
    | DocPageResponse
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py
new file mode 100644
index 0000000000..7d9d6371e9
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py
@@ -0,0 +1,14 @@
+"""Claude Agent SDK integration for CoPilot.
+
+This module provides the integration layer between the Claude Agent SDK
+and the existing CoPilot tool system, enabling drop-in replacement of
+the current LLM orchestration with the battle-tested Claude Agent SDK.
+"""
+
+from .service import stream_chat_completion_sdk
+from .tool_adapter import create_copilot_mcp_server
+
+__all__ = [
+    "stream_chat_completion_sdk",
+    "create_copilot_mcp_server",
+]
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py
new file mode 100644
index 0000000000..f7151f8319
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py
@@ -0,0 +1,203 @@
+"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
+
+This module provides the adapter layer that converts streaming messages from
+the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
+the frontend expects.
+"""
+
+import json
+import logging
+import uuid
+
+from claude_agent_sdk import (
+    AssistantMessage,
+    Message,
+    ResultMessage,
+    SystemMessage,
+    TextBlock,
+    ToolResultBlock,
+    ToolUseBlock,
+    UserMessage,
+)
+
+from backend.api.features.chat.response_model import (
+    StreamBaseResponse,
+    StreamError,
+    StreamFinish,
+    StreamFinishStep,
+    StreamStart,
+    StreamStartStep,
+    StreamTextDelta,
+    StreamTextEnd,
+    StreamTextStart,
+    StreamToolInputAvailable,
+    StreamToolInputStart,
+    StreamToolOutputAvailable,
+)
+from backend.api.features.chat.sdk.tool_adapter import (
+    MCP_TOOL_PREFIX,
+    pop_pending_tool_output,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SDKResponseAdapter:
+    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
+
+    This class maintains state during a streaming session to properly track
+    text blocks, tool calls, and message lifecycle.
+ """ + + def __init__(self, message_id: str | None = None): + self.message_id = message_id or str(uuid.uuid4()) + self.text_block_id = str(uuid.uuid4()) + self.has_started_text = False + self.has_ended_text = False + self.current_tool_calls: dict[str, dict[str, str]] = {} + self.task_id: str | None = None + self.step_open = False + + def set_task_id(self, task_id: str) -> None: + """Set the task ID for reconnection support.""" + self.task_id = task_id + + def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format.""" + responses: list[StreamBaseResponse] = [] + + if isinstance(sdk_message, SystemMessage): + if sdk_message.subtype == "init": + responses.append( + StreamStart(messageId=self.message_id, taskId=self.task_id) + ) + # Open the first step (matches non-SDK: StreamStart then StreamStartStep) + responses.append(StreamStartStep()) + self.step_open = True + + elif isinstance(sdk_message, AssistantMessage): + # After tool results, the SDK sends a new AssistantMessage for the + # next LLM turn. Open a new step if the previous one was closed. + if not self.step_open: + responses.append(StreamStartStep()) + self.step_open = True + + for block in sdk_message.content: + if isinstance(block, TextBlock): + if block.text: + self._ensure_text_started(responses) + responses.append( + StreamTextDelta(id=self.text_block_id, delta=block.text) + ) + + elif isinstance(block, ToolUseBlock): + self._end_text_if_open(responses) + + # Strip MCP prefix so frontend sees "find_block" + # instead of "mcp__copilot__find_block". + tool_name = block.name.removeprefix(MCP_TOOL_PREFIX) + + responses.append( + StreamToolInputStart(toolCallId=block.id, toolName=tool_name) + ) + responses.append( + StreamToolInputAvailable( + toolCallId=block.id, + toolName=tool_name, + input=block.input, + ) + ) + self.current_tool_calls[block.id] = {"name": tool_name} + + elif isinstance(sdk_message, UserMessage): + # UserMessage carries tool results back from tool execution. + content = sdk_message.content + blocks = content if isinstance(content, list) else [] + for block in blocks: + if isinstance(block, ToolResultBlock) and block.tool_use_id: + tool_info = self.current_tool_calls.get(block.tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + + # Prefer the stashed full output over the SDK's + # (potentially truncated) ToolResultBlock content. + # The SDK truncates large results, writing them to disk, + # which breaks frontend widget parsing. + output = pop_pending_tool_output(tool_name) or ( + _extract_tool_output(block.content) + ) + + responses.append( + StreamToolOutputAvailable( + toolCallId=block.tool_use_id, + toolName=tool_name, + output=output, + success=not (block.is_error or False), + ) + ) + + # Close the current step after tool results — the next + # AssistantMessage will open a new step for the continuation. + if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + elif isinstance(sdk_message, ResultMessage): + self._end_text_if_open(responses) + # Close the step before finishing. 
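+            # The resulting order (StreamTextEnd, StreamFinishStep, then
+            # StreamFinish, or StreamError followed by StreamFinish on error)
+            # is what response_adapter_test.py asserts for ResultMessage handling.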
+ if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + if sdk_message.subtype == "success": + responses.append(StreamFinish()) + elif sdk_message.subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "result", None) or "Unknown error" + responses.append( + StreamError(errorText=str(error_msg), code="sdk_error") + ) + responses.append(StreamFinish()) + else: + logger.warning( + f"Unexpected ResultMessage subtype: {sdk_message.subtype}" + ) + responses.append(StreamFinish()) + + else: + logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}") + + return responses + + def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None: + """Start (or restart) a text block if needed.""" + if not self.has_started_text or self.has_ended_text: + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None: + """End the current text block if one is open.""" + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + +def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str: + """Extract a string output from a ToolResultBlock's content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + if parts: + return "".join(parts) + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) + if content is None: + return "" + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py new file mode 100644 index 0000000000..a4f2502642 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py @@ -0,0 +1,366 @@ +"""Unit tests for the SDK response adapter.""" + +from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .response_adapter import SDKResponseAdapter +from .tool_adapter import MCP_TOOL_PREFIX + + +def _adapter() -> SDKResponseAdapter: + a = SDKResponseAdapter(message_id="msg-1") + a.set_task_id("task-1") + return a + + +# -- SystemMessage ----------------------------------------------------------- + + +def test_system_init_emits_start_and_step(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="init", data={})) + assert len(results) == 2 + assert isinstance(results[0], StreamStart) + assert results[0].messageId == "msg-1" + assert results[0].taskId == "task-1" + assert isinstance(results[1], StreamStartStep) + + +def test_system_non_init_emits_nothing(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="other", data={})) + assert results == [] + + +# -- 
AssistantMessage with TextBlock ----------------------------------------- + + +def test_text_block_emits_step_start_and_delta(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="hello")], model="test") + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "hello" + + +def test_empty_text_block_emits_only_step(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="")], model="test") + results = adapter.convert_message(msg) + # Empty text skipped, but step still opens + assert len(results) == 1 + assert isinstance(results[0], StreamStartStep) + + +def test_multiple_text_deltas_reuse_block_id(): + adapter = _adapter() + msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test") + msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test") + r1 = adapter.convert_message(msg1) + r2 = adapter.convert_message(msg2) + # First gets step+start+delta, second only delta (block & step already started) + assert len(r1) == 3 + assert isinstance(r1[0], StreamStartStep) + assert isinstance(r1[1], StreamTextStart) + assert len(r2) == 1 + assert isinstance(r2[0], StreamTextDelta) + assert r1[1].id == r2[0].id # same block ID + + +# -- AssistantMessage with ToolUseBlock -------------------------------------- + + +def test_tool_use_emits_input_start_and_available(): + """Tool names arrive with MCP prefix and should be stripped for the frontend.""" + adapter = _adapter() + msg = AssistantMessage( + content=[ + ToolUseBlock( + id="tool-1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"q": "x"}, + ) + ], + model="test", + ) + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamToolInputStart) + assert results[1].toolCallId == "tool-1" + assert results[1].toolName == "find_agent" # prefix stripped + assert isinstance(results[2], StreamToolInputAvailable) + assert results[2].toolName == "find_agent" # prefix stripped + assert results[2].input == {"q": "x"} + + +def test_text_then_tool_ends_text_block(): + adapter = _adapter() + text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + adapter.convert_message(text_msg) # opens step + text + results = adapter.convert_message(tool_msg) + # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamToolInputStart) + + +# -- UserMessage with ToolResultBlock ---------------------------------------- + + +def test_tool_result_emits_output_and_finish_step(): + adapter = _adapter() + # First register the tool call (opens step) — SDK sends prefixed name + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})], + model="test", + ) + adapter.convert_message(tool_msg) + + # Now send tool result + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")] + ) + results = adapter.convert_message(result_msg) + assert len(results) == 2 + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].toolCallId == "t1" + assert results[0].toolName == 
"find_agent" # prefix stripped + assert results[0].output == "found 3 agents" + assert results[0].success is True + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_error(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={}) + ], + model="test", + ) + ) + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].success is False + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_list_content(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + result_msg = UserMessage( + content=[ + ToolResultBlock( + tool_use_id="t1", + content=[ + {"type": "text", "text": "line1"}, + {"type": "text", "text": "line2"}, + ], + ) + ] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].output == "line1line2" + assert isinstance(results[1], StreamFinishStep) + + +def test_string_user_message_ignored(): + """A plain string UserMessage (not tool results) produces no output.""" + adapter = _adapter() + results = adapter.convert_message(UserMessage(content="hello")) + assert results == [] + + +# -- ResultMessage ----------------------------------------------------------- + + +def test_result_success_emits_finish_step_and_finish(): + adapter = _adapter() + # Start some text first (opens step) + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="done")], model="test") + ) + msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="s1", + ) + results = adapter.convert_message(msg) + # TextEnd + FinishStep + StreamFinish + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamFinishStep) + assert isinstance(results[2], StreamFinish) + + +def test_result_error_emits_error_and_finish(): + adapter = _adapter() + msg = ResultMessage( + subtype="error", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=0, + session_id="s1", + result="API rate limited", + ) + results = adapter.convert_message(msg) + # No step was open, so no FinishStep — just Error + Finish + assert len(results) == 2 + assert isinstance(results[0], StreamError) + assert "API rate limited" in results[0].errorText + assert isinstance(results[1], StreamFinish) + + +# -- Text after tools (new block ID) ---------------------------------------- + + +def test_text_after_tool_gets_new_block_id(): + adapter = _adapter() + # Text -> Tool -> ToolResult -> Text should get a new text block ID and step + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="before")], model="test") + ) + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + # Send tool result (closes step) + adapter.convert_message( + UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")]) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="after")], model="test") + ) + # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta + assert 
len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "after" + + +# -- Full conversation flow -------------------------------------------------- + + +def test_full_conversation_flow(): + """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish.""" + adapter = _adapter() + all_responses: list[StreamBaseResponse] = [] + + # 1. Init + all_responses.extend( + adapter.convert_message(SystemMessage(subtype="init", data={})) + ) + # 2. Assistant text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="Let me search")], model="test") + ) + ) + # 3. Tool use + all_responses.extend( + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock( + id="t1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"query": "email"}, + ) + ], + model="test", + ) + ) + ) + # 4. Tool result + all_responses.extend( + adapter.convert_message( + UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")] + ) + ) + ) + # 5. More text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="I found 2")], model="test") + ) + ) + # 6. Result + all_responses.extend( + adapter.convert_message( + ResultMessage( + subtype="success", + duration_ms=500, + duration_api_ms=400, + is_error=False, + num_turns=2, + session_id="s1", + ) + ) + ) + + types = [type(r).__name__ for r in all_responses] + assert types == [ + "StreamStart", + "StreamStartStep", # step 1: text + tool call + "StreamTextStart", + "StreamTextDelta", # "Let me search" + "StreamTextEnd", # closed before tool + "StreamToolInputStart", + "StreamToolInputAvailable", + "StreamToolOutputAvailable", # tool result + "StreamFinishStep", # step 1 closed after tool result + "StreamStartStep", # step 2: continuation text + "StreamTextStart", # new block after tool + "StreamTextDelta", # "I found 2" + "StreamTextEnd", # closed by result + "StreamFinishStep", # step 2 closed + "StreamFinish", + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py new file mode 100644 index 0000000000..89853402a3 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -0,0 +1,305 @@ +"""Security hooks for Claude Agent SDK integration. + +This module provides security hooks that validate tool calls before execution, +ensuring multi-user isolation and preventing unauthorized operations. +""" + +import json +import logging +import os +import re +from collections.abc import Callable +from typing import Any, cast + +from backend.api.features.chat.sdk.tool_adapter import ( + BLOCKED_TOOLS, + DANGEROUS_PATTERNS, + MCP_TOOL_PREFIX, + WORKSPACE_SCOPED_TOOLS, +) + +logger = logging.getLogger(__name__) + + +def _deny(reason: str) -> dict[str, Any]: + """Return a hook denial response.""" + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": reason, + } + } + + +def _validate_workspace_path( + tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None +) -> dict[str, Any]: + """Validate that a workspace-scoped tool only accesses allowed paths. 
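+
+    For illustration (example paths): with ``sdk_cwd="/tmp/copilot-abc123"``,
+    a relative ``file_path="notes.txt"`` resolves to
+    ``/tmp/copilot-abc123/notes.txt`` and is allowed, while
+    ``file_path="../../etc/passwd"`` resolves outside the workspace and is
+    denied.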
+
+    Allowed directories:
+    - The SDK working directory (``/tmp/copilot-<session_id>/``)
+    - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
+    """
+    path = tool_input.get("file_path") or tool_input.get("path") or ""
+    if not path:
+        # Glob/Grep without a path default to cwd, which is already sandboxed
+        return {}
+
+    # Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
+    # naturally uses relative paths like "test.txt" instead of absolute ones).
+    # Tilde paths (~/) are home-dir references, not relative — expand first.
+    if path.startswith("~"):
+        resolved = os.path.realpath(os.path.expanduser(path))
+    elif not os.path.isabs(path) and sdk_cwd:
+        resolved = os.path.realpath(os.path.join(sdk_cwd, path))
+    else:
+        resolved = os.path.realpath(path)
+
+    # Allow access within the SDK working directory
+    if sdk_cwd:
+        norm_cwd = os.path.realpath(sdk_cwd)
+        if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
+            return {}
+
+    # Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
+    claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
+    tool_results_seg = os.sep + "tool-results" + os.sep
+    if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
+        return {}
+
+    logger.warning(
+        f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
+    )
+    workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
+    return _deny(
+        f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
+        f"directory.{workspace_hint} "
+        "This is enforced by the platform and cannot be bypassed."
+    )
+
+
+def _validate_tool_access(
+    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
+) -> dict[str, Any]:
+    """Validate that a tool call is allowed.
+
+    Returns:
+        Empty dict to allow, or dict with hookSpecificOutput to deny.
+    """
+    # Block forbidden tools
+    if tool_name in BLOCKED_TOOLS:
+        logger.warning(f"Blocked tool access attempt: {tool_name}")
+        return _deny(
+            f"[SECURITY] Tool '{tool_name}' is blocked for security. "
+            "This is enforced by the platform and cannot be bypassed. "
+            "Use the CoPilot-specific MCP tools instead."
+        )
+
+    # Workspace-scoped tools: allowed only within the SDK workspace directory
+    if tool_name in WORKSPACE_SCOPED_TOOLS:
+        return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
+
+    # Check for dangerous patterns in tool input.
+    # Use json.dumps for a predictable format (str() produces Python repr).
+    input_str = json.dumps(tool_input) if tool_input else ""
+
+    for pattern in DANGEROUS_PATTERNS:
+        if re.search(pattern, input_str, re.IGNORECASE):
+            logger.warning(
+                f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
+            )
+            return _deny(
+                "[SECURITY] Input contains a blocked pattern. "
+                "This is enforced by the platform and cannot be bypassed."
+            )
+
+    return {}
+
+
+def _validate_user_isolation(
+    tool_name: str, tool_input: dict[str, Any], user_id: str | None
+) -> dict[str, Any]:
+    """Validate that tool calls respect user isolation."""
+    # For workspace file tools, ensure the path doesn't escape
+    if "workspace" in tool_name.lower():
+        path = tool_input.get("path", "") or tool_input.get("file_path", "")
+        if path:
+            # Check for path traversal
+            if ".."
in path or path.startswith("/"): + logger.warning( + f"Blocked path traversal attempt: {path} by user {user_id}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Path traversal not allowed", + } + } + + return {} + + +def create_security_hooks( + user_id: str | None, + sdk_cwd: str | None = None, + max_subtasks: int = 3, + on_stop: Callable[[str, str], None] | None = None, +) -> dict[str, Any]: + """Create the security hooks configuration for Claude Agent SDK. + + Includes security validation and observability hooks: + - PreToolUse: Security validation before tool execution + - PostToolUse: Log successful tool executions + - PostToolUseFailure: Log and handle failed tool executions + - PreCompact: Log context compaction events (SDK handles compaction automatically) + - Stop: Capture transcript path for stateless resume (when *on_stop* is provided) + + Args: + user_id: Current user ID for isolation validation + sdk_cwd: SDK working directory for workspace-scoped tool validation + max_subtasks: Maximum Task (sub-agent) spawns allowed per session + on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when + the SDK finishes processing — used to read the JSONL transcript + before the CLI process exits. + + Returns: + Hooks configuration dict for ClaudeAgentOptions + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + # Per-session counter for Task sub-agent spawns + task_spawn_count = 0 + + async def pre_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Combined pre-tool-use validation hook.""" + nonlocal task_spawn_count + _ = context # unused but required by signature + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Rate-limit Task (sub-agent) spawns per session + if tool_name == "Task": + task_spawn_count += 1 + if task_spawn_count > max_subtasks: + logger.warning( + f"[SDK] Task limit reached ({max_subtasks}), user={user_id}" + ) + return cast( + SyncHookJSONOutput, + _deny( + f"Maximum {max_subtasks} sub-tasks per session. " + "Please continue in the main conversation." + ), + ) + + # Strip MCP prefix for consistent validation + is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX) + clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX) + + # Only block non-CoPilot tools; our MCP-registered tools + # (including Read for oversized results) are already sandboxed. 
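+            # Validation order below: blocked-tool check, then workspace-path
+            # check, then dangerous-pattern scan (all in _validate_tool_access),
+            # followed by the user-isolation check that runs for every tool.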
+ if not is_copilot_tool: + result = _validate_tool_access(clean_name, tool_input, sdk_cwd) + if result: + return cast(SyncHookJSONOutput, result) + + # Validate user isolation + result = _validate_user_isolation(clean_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log successful tool executions for observability.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_failure_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log failed tool executions for debugging.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + error = input_data.get("error", "Unknown error") + logger.warning( + f"[SDK] Tool failed: {tool_name}, error={error}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + return cast(SyncHookJSONOutput, {}) + + async def pre_compact_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log when SDK triggers context compaction. + + The SDK automatically compacts conversation history when it grows too large. + This hook provides visibility into when compaction happens. + """ + _ = context, tool_use_id + trigger = input_data.get("trigger", "auto") + logger.info( + f"[SDK] Context compaction triggered: {trigger}, user={user_id}" + ) + return cast(SyncHookJSONOutput, {}) + + # --- Stop hook: capture transcript path for stateless resume --- + async def stop_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Capture transcript path when SDK finishes processing. + + The Stop hook fires while the CLI process is still alive, giving us + a reliable window to read the JSONL transcript before SIGTERM. + """ + _ = context, tool_use_id + transcript_path = cast(str, input_data.get("transcript_path", "")) + sdk_session_id = cast(str, input_data.get("session_id", "")) + + if transcript_path and on_stop: + logger.info( + f"[SDK] Stop hook: transcript_path={transcript_path}, " + f"sdk_session_id={sdk_session_id[:12]}..." 
+ ) + on_stop(transcript_path, sdk_session_id) + + return cast(SyncHookJSONOutput, {}) + + hooks: dict[str, Any] = { + "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])], + "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])], + "PostToolUseFailure": [ + HookMatcher(matcher="*", hooks=[post_tool_failure_hook]) + ], + "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])], + } + + if on_stop is not None: + hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])] + + return hooks + except ImportError: + # Fallback for when SDK isn't available - return empty hooks + logger.warning("claude-agent-sdk not available, security hooks disabled") + return {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py new file mode 100644 index 0000000000..2d09afdab7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py @@ -0,0 +1,165 @@ +"""Unit tests for SDK security hooks.""" + +import os + +from .security_hooks import _validate_tool_access, _validate_user_isolation + +SDK_CWD = "/tmp/copilot-abc123" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +# -- Blocked tools ----------------------------------------------------------- + + +def test_blocked_tools_denied(): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"{tool} should be blocked" + + +def test_unknown_tool_allowed(): + result = _validate_tool_access("SomeCustomTool", {}) + assert result == {} + + +# -- Workspace-scoped tools -------------------------------------------------- + + +def test_read_within_workspace_allowed(): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_write_within_workspace_allowed(): + result = _validate_tool_access( + "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_edit_within_workspace_allowed(): + result = _validate_tool_access( + "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_glob_within_workspace_allowed(): + result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_grep_within_workspace_allowed(): + result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_outside_workspace_denied(): + result = _validate_tool_access( + "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_write_outside_workspace_denied(): + result = _validate_tool_access( + "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_traversal_attack_denied(): + result = _validate_tool_access( + "Read", + {"file_path": f"{SDK_CWD}/../../etc/passwd"}, + sdk_cwd=SDK_CWD, + ) + assert _is_denied(result) + + +def test_no_path_allowed(): + """Glob/Grep without a path argument defaults to cwd — should pass.""" + result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_no_cwd_denies_absolute(): + """If no sdk_cwd is set, absolute paths are denied.""" + result = _validate_tool_access("Read", {"file_path": 
"/tmp/anything"}) + assert _is_denied(result) + + +# -- Tool-results directory -------------------------------------------------- + + +def test_read_tool_results_allowed(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_claude_projects_without_tool_results_denied(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Built-in Bash is blocked (use bash_exec MCP tool instead) --------------- + + +def test_bash_builtin_always_blocked(): + """SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Dangerous patterns ------------------------------------------------------ + + +def test_dangerous_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"}) + assert _is_denied(result) + + +def test_subprocess_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"}) + assert _is_denied(result) + + +# -- User isolation ---------------------------------------------------------- + + +def test_workspace_path_traversal_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_absolute_path_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "/etc/passwd"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_normal_path_allowed(): + result = _validate_user_isolation( + "workspace_read", {"path": "src/main.py"}, user_id="user-1" + ) + assert result == {} + + +def test_non_workspace_tool_passes_isolation(): + result = _validate_user_isolation( + "find_agent", {"query": "email"}, user_id="user-1" + ) + assert result == {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py new file mode 100644 index 0000000000..65c4cebb06 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -0,0 +1,752 @@ +"""Claude Agent SDK service layer for CoPilot chat completions.""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Any + +from backend.util.exceptions import NotFoundError + +from .. 
import stream_registry +from ..config import ChatConfig +from ..model import ( + ChatMessage, + ChatSession, + get_chat_session, + update_session_title, + upsert_chat_session, +) +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamToolInputAvailable, + StreamToolOutputAvailable, +) +from ..service import ( + _build_system_prompt, + _execute_long_running_tool_with_streaming, + _generate_session_title, +) +from ..tools.models import OperationPendingResponse, OperationStartedResponse +from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path +from ..tracking import track_user_message +from .response_adapter import SDKResponseAdapter +from .security_hooks import create_security_hooks +from .tool_adapter import ( + COPILOT_TOOL_NAMES, + SDK_DISALLOWED_TOOLS, + LongRunningCallback, + create_copilot_mcp_server, + set_execution_context, +) +from .transcript import ( + download_transcript, + read_transcript_file, + upload_transcript, + validate_transcript, + write_transcript_to_tempfile, +) + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Set to hold background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[Any]] = set() + + +@dataclass +class CapturedTranscript: + """Info captured by the SDK Stop hook for stateless --resume.""" + + path: str = "" + sdk_session_id: str = "" + + @property + def available(self) -> bool: + return bool(self.path) + + +_SDK_CWD_PREFIX = WORKSPACE_PREFIX + +# Appended to the system prompt to inform the agent about available tools. +# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). +_SDK_TOOL_SUPPLEMENT = """ + +## Tool notes + +- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool + for shell commands — it runs in a network-isolated sandbox. +- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the + same working directory. Files created by one are readable by the other. + These files are **ephemeral** — they exist only for the current session. +- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file` + for files that should persist across sessions (stored in cloud storage). +- Long-running tools (create_agent, edit_agent, etc.) are handled + asynchronously. You will receive an immediate response; the actual result + is delivered to the user via a background stream. +""" + + +def _build_long_running_callback(user_id: str | None) -> LongRunningCallback: + """Build a callback that delegates long-running tools to the non-SDK infrastructure. + + Long-running tools (create_agent, edit_agent, etc.) are delegated to the + existing background infrastructure: stream_registry (Redis Streams), + database persistence, and SSE reconnection. This means results survive + page refreshes / pod restarts, and the frontend shows the proper loading + widget with progress updates. + + The returned callback matches the ``LongRunningCallback`` signature: + ``(tool_name, args, session) -> MCP response dict``. 
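+
+    For illustration, a successfully-started operation resolves to an
+    MCP-shaped dict like:
+
+        {
+            "content": [{"type": "text", "text": "<OperationStartedResponse JSON>"}],
+            "isError": False,
+        }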
+ """ + + async def _callback( + tool_name: str, args: dict[str, Any], session: ChatSession + ) -> dict[str, Any]: + operation_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}" + session_id = session.session_id + + # --- Build user-friendly messages (matches non-SDK service) --- + if tool_name == "create_agent": + desc = args.get("description", "") + desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc + pending_msg = ( + f"Creating your agent: {desc_preview}" + if desc_preview + else "Creating agent... This may take a few minutes." + ) + started_msg = ( + "Agent creation started. You can close this tab - " + "check your library in a few minutes." + ) + elif tool_name == "edit_agent": + changes = args.get("changes", "") + changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes + pending_msg = ( + f"Editing agent: {changes_preview}" + if changes_preview + else "Editing agent... This may take a few minutes." + ) + started_msg = ( + "Agent edit started. You can close this tab - " + "check your library in a few minutes." + ) + else: + pending_msg = f"Running {tool_name}... This may take a few minutes." + started_msg = ( + f"{tool_name} started. You can close this tab - " + "check back in a few minutes." + ) + + # --- Register task in Redis for SSE reconnection --- + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # --- Save OperationPendingResponse to chat history --- + pending_message = ChatMessage( + role="tool", + content=OperationPendingResponse( + message=pending_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + tool_call_id=tool_call_id, + ) + session.messages.append(pending_message) + await upsert_chat_session(session) + + # --- Spawn background task (reuses non-SDK infrastructure) --- + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( + tool_name=tool_name, + parameters=args, + tool_call_id=tool_call_id, + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + user_id=user_id, + ) + ) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + logger.info( + f"[SDK] Long-running tool {tool_name} delegated to background " + f"(operation_id={operation_id}, task_id={task_id})" + ) + + # --- Return OperationStartedResponse as MCP tool result --- + # This flows through SDK → response adapter → frontend, triggering + # the loading widget with SSE reconnection support. + started_json = OperationStartedResponse( + message=started_msg, + operation_id=operation_id, + tool_name=tool_name, + task_id=task_id, + ).model_dump_json() + + return { + "content": [{"type": "text", "text": started_json}], + "isError": False, + } + + return _callback + + +def _resolve_sdk_model() -> str | None: + """Resolve the model name for the Claude Agent SDK CLI. + + Uses ``config.claude_agent_model`` if set, otherwise derives from + ``config.model`` by stripping the OpenRouter provider prefix (e.g., + ``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``). + """ + if config.claude_agent_model: + return config.claude_agent_model + model = config.model + if "/" in model: + return model.split("/", 1)[1] + return model + + +def _build_sdk_env() -> dict[str, str]: + """Build env vars for the SDK CLI process. 
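+
+    (Illustrative mapping, assuming an OpenRouter-style URL: with
+    ``base_url="https://openrouter.ai/api/v1"``, the CLI receives
+    ``ANTHROPIC_BASE_URL="https://openrouter.ai/api"``,
+    ``ANTHROPIC_AUTH_TOKEN=config.api_key``, and ``ANTHROPIC_API_KEY=""``.)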
+
+    Routes API calls through OpenRouter (or a custom base_url) using
+    the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
+    This gives per-call token and cost tracking on the OpenRouter dashboard.
+
+    Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
+    token are both present — otherwise returns an empty dict so the SDK
+    falls back to its default credentials.
+    """
+    env: dict[str, str] = {}
+    if config.api_key and config.base_url:
+        # Strip /v1 suffix — SDK expects the base URL without a version path
+        base = config.base_url.rstrip("/")
+        if base.endswith("/v1"):
+            base = base[:-3]
+        if not base or not base.startswith("http"):
+            # Invalid base_url — don't override SDK defaults
+            return env
+        env["ANTHROPIC_BASE_URL"] = base
+        env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
+        # Must be explicitly empty so the CLI uses AUTH_TOKEN instead
+        env["ANTHROPIC_API_KEY"] = ""
+    return env
+
+
+def _make_sdk_cwd(session_id: str) -> str:
+    """Create a safe, session-specific working directory path.
+
+    Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
+    (single source of truth for path sanitization) and adds a defence-in-depth
+    assertion.
+    """
+    cwd = make_session_path(session_id)
+    # Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
+    cwd = os.path.normpath(cwd)
+    if not cwd.startswith(_SDK_CWD_PREFIX):
+        raise ValueError(f"SDK cwd escaped prefix: {cwd}")
+    return cwd
+
+
+def _cleanup_sdk_tool_results(cwd: str) -> None:
+    """Remove SDK tool-result files for a specific session working directory.
+
+    The SDK creates tool-result files under
+    ~/.claude/projects/<encoded-cwd>/tool-results/.
+    We clean only the specific cwd's results to avoid race conditions between
+    concurrent sessions.
+
+    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
+    """
+    import shutil
+
+    # Security check 1: validate cwd is under the expected prefix
+    normalized = os.path.normpath(cwd)
+    if not normalized.startswith(_SDK_CWD_PREFIX):
+        logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
+        return
+
+    # SDK encodes the cwd path by replacing '/' with '-'
+    encoded_cwd = normalized.replace("/", "-")
+
+    # Construct the project directory path (known-safe home expansion)
+    claude_projects = os.path.expanduser("~/.claude/projects")
+    project_dir = os.path.join(claude_projects, encoded_cwd)
+
+    # Security check 2: validate project_dir is under ~/.claude/projects
+    project_dir = os.path.normpath(project_dir)
+    if not project_dir.startswith(claude_projects):
+        logger.warning(
+            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
+        )
+        return
+
+    results_dir = os.path.join(project_dir, "tool-results")
+    if os.path.isdir(results_dir):
+        for filename in os.listdir(results_dir):
+            file_path = os.path.join(results_dir, filename)
+            try:
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            except OSError:
+                pass
+
+    # Also clean up the temp cwd directory itself
+    try:
+        shutil.rmtree(normalized, ignore_errors=True)
+    except OSError:
+        pass
+
+
+async def _compress_conversation_history(
+    session: ChatSession,
+) -> list[ChatMessage]:
+    """Compress prior conversation messages if they exceed the token threshold.
+
+    Uses the shared compress_context() from prompt.py which supports:
+    - LLM summarization of old messages (keeps recent ones intact)
+    - Progressive content truncation as fallback
+    - Middle-out deletion as last resort
+
+    Returns the compressed prior messages (everything except the current message).
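+
+    A sketch of the call this wraps (``compress_context`` from
+    ``backend.util.prompt``, as used in the body below):
+
+        result = await compress_context(
+            messages=messages_dict, model=config.model, client=client
+        )
+        if result.was_compacted:
+            ...  # result.messages holds the summarized/truncated history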
+ """ + prior = session.messages[:-1] + if len(prior) < 2: + return prior + + from backend.util.prompt import compress_context + + # Convert ChatMessages to dicts for compress_context + messages_dict = [] + for msg in prior: + msg_dict: dict[str, Any] = {"role": msg.role} + if msg.content: + msg_dict["content"] = msg.content + if msg.tool_calls: + msg_dict["tool_calls"] = msg.tool_calls + if msg.tool_call_id: + msg_dict["tool_call_id"] = msg.tool_call_id + messages_dict.append(msg_dict) + + try: + import openai + + async with openai.AsyncOpenAI( + api_key=config.api_key, base_url=config.base_url, timeout=30.0 + ) as client: + result = await compress_context( + messages=messages_dict, + model=config.model, + client=client, + ) + except Exception as e: + logger.warning(f"[SDK] Context compression with LLM failed: {e}") + # Fall back to truncation-only (no LLM summarization) + result = await compress_context( + messages=messages_dict, + model=config.model, + client=None, + ) + + if result.was_compacted: + logger.info( + f"[SDK] Context compacted: {result.original_token_count} -> " + f"{result.token_count} tokens " + f"({result.messages_summarized} summarized, " + f"{result.messages_dropped} dropped)" + ) + # Convert compressed dicts back to ChatMessages + return [ + ChatMessage( + role=m["role"], + content=m.get("content"), + tool_calls=m.get("tool_calls"), + tool_call_id=m.get("tool_call_id"), + ) + for m in result.messages + ] + + return prior + + +def _format_conversation_context(messages: list[ChatMessage]) -> str | None: + """Format conversation messages into a context prefix for the user message. + + Returns a string like: + + User: hello + You responded: Hi! How can I help? + + + Returns None if there are no messages to format. + """ + if not messages: + return None + + lines: list[str] = [] + for msg in messages: + if not msg.content: + continue + if msg.role == "user": + lines.append(f"User: {msg.content}") + elif msg.role == "assistant": + lines.append(f"You responded: {msg.content}") + # Skip tool messages — they're internal details + + if not lines: + return None + + return "\n" + "\n".join(lines) + "\n" + + +async def stream_chat_completion_sdk( + session_id: str, + message: str | None = None, + tool_call_response: str | None = None, # noqa: ARG001 + is_user_message: bool = True, + user_id: str | None = None, + retry_count: int = 0, # noqa: ARG001 + session: ChatSession | None = None, + context: dict[str, str] | None = None, # noqa: ARG001 +) -> AsyncGenerator[StreamBaseResponse, None]: + """Stream chat completion using Claude Agent SDK. + + Drop-in replacement for stream_chat_completion with improved reliability. + """ + + if session is None: + session = await get_chat_session(session_id, user_id) + + if not session: + raise NotFoundError( + f"Session {session_id} not found. Please create a new session first." 
+ ) + + if message: + session.messages.append( + ChatMessage( + role="user" if is_user_message else "assistant", content=message + ) + ) + if is_user_message: + track_user_message( + user_id=user_id, session_id=session_id, message_length=len(message) + ) + + session = await upsert_chat_session(session) + + # Generate title for new sessions (first user message) + if is_user_message and not session.title: + user_messages = [m for m in session.messages if m.role == "user"] + if len(user_messages) == 1: + first_message = user_messages[0].content or message or "" + if first_message: + task = asyncio.create_task( + _update_title_async(session_id, first_message, user_id) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + # Build system prompt (reuses non-SDK path with Langfuse support) + has_history = len(session.messages) > 1 + system_prompt, _ = await _build_system_prompt( + user_id, has_conversation_history=has_history + ) + system_prompt += _SDK_TOOL_SUPPLEMENT + message_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + + yield StreamStart(messageId=message_id, taskId=task_id) + + stream_completed = False + # Initialise sdk_cwd before the try so the finally can reference it + # even if _make_sdk_cwd raises (in that case it stays as ""). + sdk_cwd = "" + use_resume = False + + try: + # Use a session-specific temp dir to avoid cleanup race conditions + # between concurrent sessions. + sdk_cwd = _make_sdk_cwd(session_id) + os.makedirs(sdk_cwd, exist_ok=True) + + set_execution_context( + user_id, + session, + long_running_callback=_build_long_running_callback(user_id), + ) + try: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + # Fail fast when no API credentials are available at all + sdk_env = _build_sdk_env() + if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"): + raise RuntimeError( + "No API key configured. Set OPEN_ROUTER_API_KEY " + "(or CHAT_API_KEY) for OpenRouter routing, " + "or ANTHROPIC_API_KEY for direct Anthropic access." 
+ ) + + mcp_server = create_copilot_mcp_server() + + sdk_model = _resolve_sdk_model() + + # --- Transcript capture via Stop hook --- + captured_transcript = CapturedTranscript() + + def _on_stop(transcript_path: str, sdk_session_id: str) -> None: + captured_transcript.path = transcript_path + captured_transcript.sdk_session_id = sdk_session_id + + security_hooks = create_security_hooks( + user_id, + sdk_cwd=sdk_cwd, + max_subtasks=config.claude_agent_max_subtasks, + on_stop=_on_stop if config.claude_agent_use_resume else None, + ) + + # --- Resume strategy: download transcript from bucket --- + resume_file: str | None = None + use_resume = False + + if config.claude_agent_use_resume and user_id and len(session.messages) > 1: + transcript_content = await download_transcript(user_id, session_id) + if transcript_content and validate_transcript(transcript_content): + resume_file = write_transcript_to_tempfile( + transcript_content, session_id, sdk_cwd + ) + if resume_file: + use_resume = True + logger.info( + f"[SDK] Using --resume with transcript " + f"({len(transcript_content)} bytes)" + ) + + sdk_options_kwargs: dict[str, Any] = { + "system_prompt": system_prompt, + "mcp_servers": {"copilot": mcp_server}, + "allowed_tools": COPILOT_TOOL_NAMES, + "disallowed_tools": SDK_DISALLOWED_TOOLS, + "hooks": security_hooks, + "cwd": sdk_cwd, + "max_buffer_size": config.claude_agent_max_buffer_size, + } + if sdk_env: + sdk_options_kwargs["model"] = sdk_model + sdk_options_kwargs["env"] = sdk_env + if use_resume and resume_file: + sdk_options_kwargs["resume"] = resume_file + + options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] + + adapter = SDKResponseAdapter(message_id=message_id) + adapter.set_task_id(task_id) + + async with ClaudeSDKClient(options=options) as client: + current_message = message or "" + if not current_message and session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" + + if not current_message.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return + + # Build query: with --resume the CLI already has full + # context, so we only send the new message. Without + # resume, compress history into a context prefix. 
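+                # Illustrative fallback shape (see _format_conversation_context):
+                #     User: hello
+                #     You responded: Hi! How can I help?
+                #
+                #     Now, the user says:
+                #     <current message>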
+ query_message = current_message + if not use_resume and len(session.messages) > 1: + logger.warning( + f"[SDK] Using compression fallback for session " + f"{session_id} ({len(session.messages)} messages) — " + f"no transcript available for --resume" + ) + compressed = await _compress_conversation_history(session) + history_context = _format_conversation_context(compressed) + if history_context: + query_message = ( + f"{history_context}\n\n" + f"Now, the user says:\n{current_message}" + ) + + logger.info( + f"[SDK] Sending query ({len(session.messages)} msgs in session)" + ) + logger.debug(f"[SDK] Query preview: {current_message[:80]!r}") + await client.query(query_message, session_id=session_id) + + assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] + has_appended_assistant = False + has_tool_results = False + + async for sdk_msg in client.receive_messages(): + logger.debug( + f"[SDK] Received: {type(sdk_msg).__name__} " + f"{getattr(sdk_msg, 'subtype', '')}" + ) + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + + yield response + + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, start a new assistant + # message for the post-tool text. + if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + accumulated_tool_calls = [] + has_appended_assistant = False + has_tool_results = False + session.messages.append(assistant_response) + has_appended_assistant = True + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps(response.input or {}), + }, + } + ) + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else str(response.output) + ), + tool_call_id=response.toolCallId, + ) + ) + has_tool_results = True + + elif isinstance(response, StreamFinish): + stream_completed = True + + if stream_completed: + break + + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: + session.messages.append(assistant_response) + + # --- Capture transcript while CLI is still alive --- + # Must happen INSIDE async with: close() sends SIGTERM + # which kills the CLI before it can flush the JSONL. + if ( + config.claude_agent_use_resume + and user_id + and captured_transcript.available + ): + # Give CLI time to flush JSONL writes before we read + await asyncio.sleep(0.5) + raw_transcript = read_transcript_file(captured_transcript.path) + if raw_transcript: + task = asyncio.create_task( + _upload_transcript_bg(user_id, session_id, raw_transcript) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + else: + logger.debug("[SDK] Stop hook fired but transcript not usable") + + except ImportError: + raise RuntimeError( + "claude-agent-sdk is not installed. 
" + "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) " + "to use the OpenAI-compatible fallback." + ) + + await upsert_chat_session(session) + logger.debug( + f"[SDK] Session {session_id} saved with {len(session.messages)} messages" + ) + if not stream_completed: + yield StreamFinish() + + except Exception as e: + logger.error(f"[SDK] Error: {e}", exc_info=True) + try: + await upsert_chat_session(session) + except Exception as save_err: + logger.error(f"[SDK] Failed to save session on error: {save_err}") + yield StreamError( + errorText="An error occurred. Please try again.", + code="sdk_error", + ) + yield StreamFinish() + finally: + if sdk_cwd: + _cleanup_sdk_tool_results(sdk_cwd) + + +async def _upload_transcript_bg( + user_id: str, session_id: str, raw_content: str +) -> None: + """Background task to strip progress entries and upload transcript.""" + try: + await upload_transcript(user_id, session_id, raw_content) + except Exception as e: + logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}") + + +async def _update_title_async( + session_id: str, message: str, user_id: str | None = None +) -> None: + """Background task to update session title.""" + try: + title = await _generate_session_title( + message, user_id=user_id, session_id=session_id + ) + if title: + await update_session_title(session_id, title) + logger.debug(f"[SDK] Generated title for {session_id}: {title}") + except Exception as e: + logger.warning(f"[SDK] Failed to update session title: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py new file mode 100644 index 0000000000..2d259730bf --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -0,0 +1,363 @@ +"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools. + +This module provides the adapter layer that converts existing BaseTool implementations +into in-process MCP tools that can be used with the Claude Agent SDK. + +Long-running tools (``is_long_running=True``) are delegated to the non-SDK +background infrastructure (stream_registry, Redis persistence, SSE reconnection) +via a callback provided by the service layer. This avoids wasteful SDK polling +and makes results survive page refreshes. +""" + +import itertools +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools import TOOL_REGISTRY +from backend.api.features.chat.tools.base import BaseTool + +logger = logging.getLogger(__name__) + +# Allowed base directory for the Read tool (SDK saves oversized tool results here). +# Restricted to ~/.claude/projects/ and further validated to require "tool-results" +# in the path — prevents reading settings, credentials, or other sensitive files. 
+_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/") + +# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}" +MCP_SERVER_NAME = "copilot" +MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__" + +# Context variables to pass user/session info to tool execution +_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None) +_current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) +# Stash for MCP tool outputs before the SDK potentially truncates them. +# Keyed by tool_name → full output string. Consumed (popped) by the +# response adapter when it builds StreamToolOutputAvailable. +_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar( + "pending_tool_outputs", default=None # type: ignore[arg-type] +) + +# Callback type for delegating long-running tools to the non-SDK infrastructure. +# Args: (tool_name, arguments, session) → MCP-formatted response dict. +LongRunningCallback = Callable[ + [str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]] +] + +# ContextVar so the service layer can inject the callback per-request. +_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar( + "long_running_callback", default=None +) + + +def set_execution_context( + user_id: str | None, + session: ChatSession, + long_running_callback: LongRunningCallback | None = None, +) -> None: + """Set the execution context for tool calls. + + This must be called before streaming begins to ensure tools have access + to user_id and session information. + + Args: + user_id: Current user's ID. + session: Current chat session. + long_running_callback: Optional callback to delegate long-running tools + to the non-SDK background infrastructure (stream_registry + Redis). + """ + _current_user_id.set(user_id) + _current_session.set(session) + _pending_tool_outputs.set({}) + _long_running_callback.set(long_running_callback) + + +def get_execution_context() -> tuple[str | None, ChatSession | None]: + """Get the current execution context.""" + return ( + _current_user_id.get(), + _current_session.get(), + ) + + +def pop_pending_tool_output(tool_name: str) -> str | None: + """Pop and return the stashed full output for *tool_name*. + + The SDK CLI may truncate large tool results (writing them to disk and + replacing the content with a file reference). This stash keeps the + original MCP output so the response adapter can forward it to the + frontend for proper widget rendering. + + Returns ``None`` if nothing was stashed for *tool_name*. + """ + pending = _pending_tool_outputs.get(None) + if pending is None: + return None + return pending.pop(tool_name, None) + + +async def _execute_tool_sync( + base_tool: BaseTool, + user_id: str | None, + session: ChatSession, + args: dict[str, Any], +) -> dict[str, Any]: + """Execute a tool synchronously and return MCP-formatted response.""" + effective_id = f"sdk-{uuid.uuid4().hex[:12]}" + result = await base_tool.execute( + user_id=user_id, + session=session, + tool_call_id=effective_id, + **args, + ) + + text = ( + result.output if isinstance(result.output, str) else json.dumps(result.output) + ) + + # Stash the full output before the SDK potentially truncates it. 
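+    # The response adapter pops this via pop_pending_tool_output(base_tool.name)
+    # when it builds StreamToolOutputAvailable, so the frontend always receives
+    # the untruncated payload.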
+    pending = _pending_tool_outputs.get(None)
+    if pending is not None:
+        pending[base_tool.name] = text
+
+    return {
+        "content": [{"type": "text", "text": text}],
+        "isError": not result.success,
+    }
+
+
+def _mcp_error(message: str) -> dict[str, Any]:
+    return {
+        "content": [
+            {"type": "text", "text": json.dumps({"error": message, "type": "error"})}
+        ],
+        "isError": True,
+    }
+
+
+def create_tool_handler(base_tool: BaseTool):
+    """Create an async handler function for a BaseTool.
+
+    This wraps the existing BaseTool._execute method to be compatible
+    with the Claude Agent SDK MCP tool format.
+
+    Long-running tools (``is_long_running=True``) are delegated to the
+    non-SDK background infrastructure via a callback set in the execution
+    context. The callback persists the operation in Redis (stream_registry)
+    so results survive page refreshes and pod restarts.
+    """
+
+    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
+        """Execute the wrapped tool and return MCP-formatted response."""
+        user_id, session = get_execution_context()
+
+        if session is None:
+            return _mcp_error("No session context available")
+
+        # --- Long-running: delegate to non-SDK background infrastructure ---
+        if base_tool.is_long_running:
+            callback = _long_running_callback.get(None)
+            if callback:
+                try:
+                    return await callback(base_tool.name, args, session)
+                except Exception as e:
+                    logger.error(
+                        f"Long-running callback failed for {base_tool.name}: {e}",
+                        exc_info=True,
+                    )
+                    return _mcp_error(f"Failed to start {base_tool.name}: {e}")
+            # No callback — fall through to synchronous execution
+            logger.warning(
+                f"[SDK] No long-running callback for {base_tool.name}, "
+                f"executing synchronously (may block)"
+            )
+
+        # --- Normal (fast) tool: execute synchronously ---
+        try:
+            return await _execute_tool_sync(base_tool, user_id, session, args)
+        except Exception as e:
+            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
+            return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
+
+    return tool_handler
+
+
+def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
+    """Build a JSON Schema input schema for a tool."""
+    return {
+        "type": "object",
+        "properties": base_tool.parameters.get("properties", {}),
+        "required": base_tool.parameters.get("required", []),
+    }
+
+
+async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
+    """Read a file with optional offset/limit. Restricted to SDK working directory.
+
+    Files are not deleted here; _cleanup_sdk_tool_results removes them after
+    the session ends to prevent accumulation in long-running pods.
+    """
+    file_path = args.get("file_path", "")
+    offset = args.get("offset", 0)
+    limit = args.get("limit", 2000)
+
+    # Security: only allow reads under ~/.claude/projects/**/tool-results/
+    real_path = os.path.realpath(file_path)
+    if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
+        return {
+            "content": [{"type": "text", "text": f"Access denied: {file_path}"}],
+            "isError": True,
+        }
+
+    try:
+        with open(real_path) as f:
+            selected = list(itertools.islice(f, offset, offset + limit))
+        content = "".join(selected)
+        # Cleanup happens in _cleanup_sdk_tool_results after session ends;
+        # don't delete here — the SDK may read in multiple chunks.
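+        # offset/limit are line-based: islice(f, offset, offset + limit)
+        # yields lines [offset, offset + limit) lazily, so even very large
+        # result files are never loaded into memory at once.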
+ return {"content": [{"type": "text", "text": content}], "isError": False} + except FileNotFoundError: + return { + "content": [{"type": "text", "text": f"File not found: {file_path}"}], + "isError": True, + } + except Exception as e: + return { + "content": [{"type": "text", "text": f"Error reading file: {e}"}], + "isError": True, + } + + +_READ_TOOL_NAME = "Read" +_READ_TOOL_DESCRIPTION = ( + "Read a file from the local filesystem. " + "Use offset and limit to read specific line ranges for large files." +) +_READ_TOOL_SCHEMA = { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to read", + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (0-indexed). Default: 0", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read. Default: 2000", + }, + }, + "required": ["file_path"], +} + + +# Create the MCP server configuration +def create_copilot_mcp_server(): + """Create an in-process MCP server configuration for CoPilot tools. + + This can be passed to ClaudeAgentOptions.mcp_servers. + + Note: The actual SDK MCP server creation depends on the claude-agent-sdk + package being available. This function returns the configuration that + can be used with the SDK. + """ + try: + from claude_agent_sdk import create_sdk_mcp_server, tool + + # Create decorated tool functions + sdk_tools = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + handler = create_tool_handler(base_tool) + decorated = tool( + tool_name, + base_tool.description, + _build_input_schema(base_tool), + )(handler) + sdk_tools.append(decorated) + + # Add the Read tool so the SDK can read back oversized tool results + read_tool = tool( + _READ_TOOL_NAME, + _READ_TOOL_DESCRIPTION, + _READ_TOOL_SCHEMA, + )(_read_file_handler) + sdk_tools.append(read_tool) + + server = create_sdk_mcp_server( + name=MCP_SERVER_NAME, + version="1.0.0", + tools=sdk_tools, + ) + + return server + + except ImportError: + # Let ImportError propagate so service.py handles the fallback + raise + + +# SDK built-in tools allowed within the workspace directory. +# Security hooks validate that file paths stay within sdk_cwd. +# Bash is NOT included — use the sandboxed MCP bash_exec tool instead, +# which provides kernel-level network isolation via unshare --net. +# Task allows spawning sub-agents (rate-limited by security hooks). +# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk. +_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"] + +# SDK built-in tools that must be explicitly blocked. +# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level +# network isolation (unshare --net) instead. +# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.). +# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead. +SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"] + +# Tools that are blocked entirely in security hooks (defence-in-depth). +# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms. +BLOCKED_TOOLS = { + *SDK_DISALLOWED_TOOLS, + "bash", + "shell", + "exec", + "terminal", + "command", +} + +# Tools allowed only when their path argument stays within the SDK workspace. +# The SDK uses these to handle oversized tool results (writes to tool-results/ +# files, then reads them back) and for workspace file operations. 
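+# A sketch of how a security hook might apply this scoping (illustrative
+# only; the real enforcement lives in the security hooks, and deny() is a
+# placeholder):
+#
+#     if tool_name in WORKSPACE_SCOPED_TOOLS:
+#         real = os.path.realpath(args.get("file_path", ""))
+#         if not real.startswith(sdk_cwd):
+#             deny("path escapes the SDK workspace")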
+WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"} + +# Dangerous patterns in tool inputs +DANGEROUS_PATTERNS = [ + r"sudo", + r"rm\s+-rf", + r"dd\s+if=", + r"/etc/passwd", + r"/etc/shadow", + r"chmod\s+777", + r"curl\s+.*\|.*sh", + r"wget\s+.*\|.*sh", + r"eval\s*\(", + r"exec\s*\(", + r"__import__", + r"os\.system", + r"subprocess", +] + +# List of tool names for allowed_tools configuration +# Include MCP tools, the MCP Read tool for oversized results, +# and SDK built-in file tools for workspace operations. +COPILOT_TOOL_NAMES = [ + *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}", + *_SDK_BUILTIN_TOOLS, +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py new file mode 100644 index 0000000000..aaa5609227 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py @@ -0,0 +1,356 @@ +"""JSONL transcript management for stateless multi-turn resume. + +The Claude Code CLI persists conversations as JSONL files (one JSON object per +line). When the SDK's ``Stop`` hook fires we read this file, strip bloat +(progress entries, metadata), and upload the result to bucket storage. On the +next turn we download the transcript, write it to a temp file, and pass +``--resume`` so the CLI can reconstruct the full conversation. + +Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local +filesystem for self-hosted) — no DB column needed. +""" + +import json +import logging +import os +import re + +logger = logging.getLogger(__name__) + +# UUIDs are hex + hyphens; strip everything else to prevent path injection. +_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]") + +# Entry types that can be safely removed from the transcript without breaking +# the parentUuid conversation tree that ``--resume`` relies on. +# - progress: UI progress ticks, no message content (avg 97KB for agent_progress) +# - file-history-snapshot: undo tracking metadata +# - queue-operation: internal queue bookkeeping +# - summary: session summaries +# - pr-link: PR link metadata +STRIPPABLE_TYPES = frozenset( + {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"} +) + +# Workspace storage constants — deterministic path from session_id. +TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" + + +# --------------------------------------------------------------------------- +# Progress stripping +# --------------------------------------------------------------------------- + + +def strip_progress_entries(content: str) -> str: + """Remove progress/metadata entries from a JSONL transcript. + + Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents + any remaining child entries so the ``parentUuid`` chain stays intact. + Typically reduces transcript size by ~30%. 
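+
+    Worked example: given the chain ``A <- B(progress) <- C``, B is dropped
+    and C's ``parentUuid`` is rewritten from B to A, so the surviving
+    entries still form a connected tree that ``--resume`` can replay.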
+ """ + lines = content.strip().split("\n") + + entries: list[dict] = [] + for line in lines: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + # Keep unparseable lines as-is (safety) + entries.append({"_raw": line}) + + stripped_uuids: set[str] = set() + uuid_to_parent: dict[str, str] = {} + kept: list[dict] = [] + + for entry in entries: + if "_raw" in entry: + kept.append(entry) + continue + uid = entry.get("uuid", "") + parent = entry.get("parentUuid", "") + entry_type = entry.get("type", "") + + if uid: + uuid_to_parent[uid] = parent + + if entry_type in STRIPPABLE_TYPES: + if uid: + stripped_uuids.add(uid) + else: + kept.append(entry) + + # Reparent: walk up chain through stripped entries to find surviving ancestor + for entry in kept: + if "_raw" in entry: + continue + parent = entry.get("parentUuid", "") + original_parent = parent + while parent in stripped_uuids: + parent = uuid_to_parent.get(parent, "") + if parent != original_parent: + entry["parentUuid"] = parent + + result_lines: list[str] = [] + for entry in kept: + if "_raw" in entry: + result_lines.append(entry["_raw"]) + else: + result_lines.append(json.dumps(entry, separators=(",", ":"))) + + return "\n".join(result_lines) + "\n" + + +# --------------------------------------------------------------------------- +# Local file I/O (read from CLI's JSONL, write temp file for --resume) +# --------------------------------------------------------------------------- + + +def read_transcript_file(transcript_path: str) -> str | None: + """Read a JSONL transcript file from disk. + + Returns the raw JSONL content, or ``None`` if the file is missing, empty, + or only contains metadata (≤2 lines with no conversation messages). + """ + if not transcript_path or not os.path.isfile(transcript_path): + logger.debug(f"[Transcript] File not found: {transcript_path}") + return None + + try: + with open(transcript_path) as f: + content = f.read() + + if not content.strip(): + logger.debug(f"[Transcript] Empty file: {transcript_path}") + return None + + lines = content.strip().split("\n") + if len(lines) < 3: + # Raw files with ≤2 lines are metadata-only + # (queue-operation + file-history-snapshot, no conversation). + logger.debug( + f"[Transcript] Too few lines ({len(lines)}): {transcript_path}" + ) + return None + + # Quick structural validation — parse first and last lines. + json.loads(lines[0]) + json.loads(lines[-1]) + + logger.info( + f"[Transcript] Read {len(lines)} lines, " + f"{len(content)} bytes from {transcript_path}" + ) + return content + + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}") + return None + + +def _sanitize_id(raw_id: str, max_len: int = 36) -> str: + """Sanitize an ID for safe use in file paths. + + Session/user IDs are expected to be UUIDs (hex + hyphens). Strip + everything else and truncate to *max_len* so the result cannot introduce + path separators or other special characters. + """ + cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len] + return cleaned or "unknown" + + +_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") + + +def write_transcript_to_tempfile( + transcript_content: str, + session_id: str, + cwd: str, +) -> str | None: + """Write JSONL transcript to a temp file inside *cwd* for ``--resume``. + + The file lives in the session working directory so it is cleaned up + automatically when the session ends. + + Returns the absolute path to the file, or ``None`` on failure. 
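+
+    For example (assuming a sandbox cwd of ``/tmp/copilot-abc123``), a
+    session ID of ``9f1b2c3d-...`` produces
+    ``/tmp/copilot-abc123/transcript-9f1b2c3d.jsonl``; the ID is reduced to
+    hex and hyphens and truncated to 8 characters.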
+ """ + # Validate cwd is under the expected sandbox prefix (CodeQL sanitizer). + real_cwd = os.path.realpath(cwd) + if not real_cwd.startswith(_SAFE_CWD_PREFIX): + logger.warning(f"[Transcript] cwd outside sandbox: {cwd}") + return None + + try: + os.makedirs(real_cwd, exist_ok=True) + safe_id = _sanitize_id(session_id, max_len=8) + jsonl_path = os.path.realpath( + os.path.join(real_cwd, f"transcript-{safe_id}.jsonl") + ) + if not jsonl_path.startswith(real_cwd): + logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}") + return None + + with open(jsonl_path, "w") as f: + f.write(transcript_content) + + logger.info(f"[Transcript] Wrote resume file: {jsonl_path}") + return jsonl_path + + except OSError as e: + logger.warning(f"[Transcript] Failed to write resume file: {e}") + return None + + +def validate_transcript(content: str | None) -> bool: + """Check that a transcript has actual conversation messages. + + A valid transcript for resume needs at least one user message and one + assistant message (not just queue-operation / file-history-snapshot + metadata). + """ + if not content or not content.strip(): + return False + + lines = content.strip().split("\n") + if len(lines) < 2: + return False + + has_user = False + has_assistant = False + + for line in lines: + try: + entry = json.loads(line) + msg_type = entry.get("type") + if msg_type == "user": + has_user = True + elif msg_type == "assistant": + has_assistant = True + except json.JSONDecodeError: + return False + + return has_user and has_assistant + + +# --------------------------------------------------------------------------- +# Bucket storage (GCS / local via WorkspaceStorageBackend) +# --------------------------------------------------------------------------- + + +def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for a session's transcript. + + Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` + IDs are sanitized to hex+hyphen to prevent path traversal. + """ + return ( + TRANSCRIPT_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.jsonl", + ) + + +def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: + """Build the full storage path string that ``retrieve()`` expects. + + ``store()`` returns a path like ``gcs://bucket/workspaces/...`` or + ``local://workspace_id/file_id/filename``. Since we use deterministic + arguments we can reconstruct the same path for download/delete without + having stored the return value. + """ + from backend.util.workspace_storage import GCSWorkspaceStorage + + wid, fid, fname = _storage_path_parts(user_id, session_id) + + if isinstance(backend, GCSWorkspaceStorage): + blob = f"workspaces/{wid}/{fid}/{fname}" + return f"gcs://{backend.bucket_name}/{blob}" + else: + # LocalWorkspaceStorage returns local://{relative_path} + return f"local://{wid}/{fid}/{fname}" + + +async def upload_transcript(user_id: str, session_id: str, content: str) -> None: + """Strip progress entries and upload transcript to bucket storage. + + Safety: only overwrites when the new (stripped) transcript is larger than + what is already stored. Since JSONL is append-only, the latest transcript + is always the longest. This prevents a slow/stale background task from + clobbering a newer upload from a concurrent turn. 
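+
+    For example, a stale background task holding a 40 KB transcript will
+    skip its upload if a concurrent turn already stored 55 KB, since the
+    stored copy is the longer (newer) one.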
+ """ + from backend.util.workspace_storage import get_workspace_storage + + stripped = strip_progress_entries(content) + if not validate_transcript(stripped): + logger.warning( + f"[Transcript] Skipping upload — stripped content is not a valid " + f"transcript for session {session_id}" + ) + return + + storage = await get_workspace_storage() + wid, fid, fname = _storage_path_parts(user_id, session_id) + encoded = stripped.encode("utf-8") + new_size = len(encoded) + + # Check existing transcript size to avoid overwriting newer with older + path = _build_storage_path(user_id, session_id, storage) + try: + existing = await storage.retrieve(path) + if len(existing) >= new_size: + logger.info( + f"[Transcript] Skipping upload — existing transcript " + f"({len(existing)}B) >= new ({new_size}B) for session " + f"{session_id}" + ) + return + except (FileNotFoundError, Exception): + pass # No existing transcript or retrieval error — proceed with upload + + await storage.store( + workspace_id=wid, + file_id=fid, + filename=fname, + content=encoded, + ) + logger.info( + f"[Transcript] Uploaded {new_size} bytes " + f"(stripped from {len(content)}) for session {session_id}" + ) + + +async def download_transcript(user_id: str, session_id: str) -> str | None: + """Download transcript from bucket storage. + + Returns the JSONL content string, or ``None`` if not found. + """ + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + data = await storage.retrieve(path) + content = data.decode("utf-8") + logger.info( + f"[Transcript] Downloaded {len(content)} bytes for session {session_id}" + ) + return content + except FileNotFoundError: + logger.debug(f"[Transcript] No transcript in storage for {session_id}") + return None + except Exception as e: + logger.warning(f"[Transcript] Failed to download transcript: {e}") + return None + + +async def delete_transcript(user_id: str, session_id: str) -> None: + """Delete transcript from bucket storage (e.g. after resume failure).""" + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + await storage.delete(path) + logger.info(f"[Transcript] Deleted transcript for session {session_id}") + except Exception as e: + logger.warning(f"[Transcript] Failed to delete transcript: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 193566ea01..cb5591e6d0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str: return DEFAULT_SYSTEM_PROMPT.format(users_information=context) -async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: +async def _build_system_prompt( + user_id: str | None, has_conversation_history: bool = False +) -> tuple[str, Any]: """Build the full system prompt including business understanding if available. Args: - user_id: The user ID for fetching business understanding - If "default" and this is the user's first session, will use "onboarding" instead. + user_id: The user ID for fetching business understanding. + has_conversation_history: Whether there's existing conversation history. 
+ If True, we don't tell the model to greet/introduce (since they're + already in a conversation). Returns: Tuple of (compiled prompt string, business understanding object) @@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: if understanding: context = format_understanding_for_prompt(understanding) + elif has_conversation_history: + context = "No prior understanding saved yet. Continue the existing conversation naturally." else: context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" @@ -374,7 +380,6 @@ async def stream_chat_completion( Raises: NotFoundError: If session_id is invalid - ValueError: If max_context_messages is exceeded """ completion_start = time.monotonic() @@ -459,8 +464,9 @@ async def stream_chat_completion( # Generate title for new sessions on first user message (non-blocking) # Check: is_user_message, no title yet, and this is the first user message - if is_user_message and message and not session.title: - user_messages = [m for m in session.messages if m.role == "user"] + user_messages = [m for m in session.messages if m.role == "user"] + first_user_msg = message or (user_messages[0].content if user_messages else None) + if is_user_message and first_user_msg and not session.title: if len(user_messages) == 1: # First user message - generate title in background import asyncio @@ -468,7 +474,7 @@ async def stream_chat_completion( # Capture only the values we need (not the session object) to avoid # stale data issues when the main flow modifies the session captured_session_id = session_id - captured_message = message + captured_message = first_user_msg captured_user_id = user_id async def _update_title(): @@ -1237,7 +1243,7 @@ async def _stream_chat_chunks( total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 logger.info( - f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; " + f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; " f"session={session.session_id}, user={session.user_id}", extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, ) @@ -1245,6 +1251,7 @@ async def _stream_chat_chunks( return except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 # Calculate delay with exponential backoff @@ -1260,12 +1267,27 @@ async def _stream_chat_chunks( continue # Retry the stream else: # Non-retryable error or max retries exceeded - logger.error( - f"Error in stream (not retrying): {e!s}", - exc_info=True, + _log_api_error( + error=e, + context="stream (not retrying)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=retry_count, ) error_code = None error_text = str(e) + + error_details = _extract_api_error_details(e) + if error_details.get("response_body"): + body = error_details["response_body"] + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict) and err.get("message"): + error_text = err["message"] + elif body.get("message"): + error_text = body["message"] + if _is_region_blocked_error(e): error_code = "MODEL_NOT_AVAILABLE_REGION" error_text = ( @@ -1282,9 +1304,13 @@ async def _stream_chat_chunks( # If we exit the retry loop without returning, it means we exhausted retries if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded. 
Last error: {last_error!s}", - exc_info=True, + _log_api_error( + error=last_error, + context=f"stream (max retries {MAX_RETRIES} exceeded)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=MAX_RETRIES, ) yield StreamError(errorText=f"Max retries exceeded: {last_error!s}") yield StreamFinish() @@ -1857,6 +1883,7 @@ async def _generate_llm_continuation( break # Success, exit retry loop except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 delay = min( @@ -1870,17 +1897,25 @@ async def _generate_llm_continuation( await asyncio.sleep(delay) continue else: - # Non-retryable error - log and exit gracefully - logger.error( - f"Non-retryable error in LLM continuation: {e!s}", - exc_info=True, + # Non-retryable error - log details and exit gracefully + _log_api_error( + error=e, + context="LLM continuation (not retrying)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=retry_count, ) return if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. " - f"Last error: {last_error!s}" + _log_api_error( + error=last_error, + context=f"LLM continuation (max retries {MAX_RETRIES} exceeded)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=MAX_RETRIES, ) return @@ -1920,6 +1955,91 @@ async def _generate_llm_continuation( logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) +def _log_api_error( + error: Exception, + context: str, + session_id: str | None = None, + message_count: int | None = None, + model: str | None = None, + retry_count: int = 0, +) -> None: + """Log detailed API error information for debugging.""" + details = _extract_api_error_details(error) + details["context"] = context + details["session_id"] = session_id + details["message_count"] = message_count + details["model"] = model + details["retry_count"] = retry_count + + if isinstance(error, RateLimitError): + logger.warning(f"Rate limit error in {context}: {details}", exc_info=error) + elif isinstance(error, APIConnectionError): + logger.warning(f"API connection error in {context}: {details}", exc_info=error) + elif isinstance(error, APIStatusError) and error.status_code >= 500: + logger.error(f"API server error (5xx) in {context}: {details}", exc_info=error) + else: + logger.error(f"API error in {context}: {details}", exc_info=error) + + +def _extract_api_error_details(error: Exception) -> dict[str, Any]: + """Extract detailed information from OpenAI/OpenRouter API errors.""" + error_msg = str(error) + details: dict[str, Any] = { + "error_type": type(error).__name__, + "error_message": error_msg[:500] + "..." 
if len(error_msg) > 500 else error_msg, + } + + if hasattr(error, "code"): + details["code"] = getattr(error, "code", None) + if hasattr(error, "param"): + details["param"] = getattr(error, "param", None) + + if isinstance(error, APIStatusError): + details["status_code"] = error.status_code + details["request_id"] = getattr(error, "request_id", None) + + if hasattr(error, "body") and error.body: + details["response_body"] = _sanitize_error_body(error.body) + + if hasattr(error, "response") and error.response: + headers = error.response.headers + details["openrouter_provider"] = headers.get("x-openrouter-provider") + details["openrouter_model"] = headers.get("x-openrouter-model") + details["retry_after"] = headers.get("retry-after") + details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining") + + return details + + +def _sanitize_error_body( + body: Any, max_length: int = 2000 +) -> dict[str, Any] | str | None: + """Extract only safe fields from error response body to avoid logging sensitive data.""" + if not isinstance(body, dict): + # Non-dict bodies (e.g., HTML error pages) - return truncated string + if body is not None: + body_str = str(body) + if len(body_str) > max_length: + return body_str[:max_length] + "...[truncated]" + return body_str + return None + + safe_fields = ("message", "type", "code", "param", "error") + sanitized: dict[str, Any] = {} + + for field in safe_fields: + if field in body: + value = body[field] + if field == "error" and isinstance(value, dict): + sanitized[field] = _sanitize_error_body(value, max_length) + elif isinstance(value, str) and len(value) > max_length: + sanitized[field] = value[:max_length] + "...[truncated]" + else: + sanitized[field] = value + + return sanitized if sanitized else None + + async def _generate_llm_continuation_with_streaming( session_id: str, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/service_test.py b/autogpt_platform/backend/backend/api/features/chat/service_test.py index 70f27af14f..b2fc82b790 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/service_test.py @@ -1,3 +1,4 @@ +import asyncio import logging from os import getenv @@ -11,6 +12,8 @@ from .response_model import ( StreamTextDelta, StreamToolOutputAvailable, ) +from .sdk import service as sdk_service +from .sdk.transcript import download_transcript logger = logging.getLogger(__name__) @@ -80,3 +83,96 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user session = await get_chat_session(session.session_id) assert session, "Session not found" assert session.usage, "Usage is empty" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): + """Test that the SDK --resume path captures and uses transcripts across turns. + + Turn 1: Send a message containing a unique keyword. + Turn 2: Ask the model to recall that keyword — proving the transcript was + persisted and restored via --resume. 
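+
+    The transcript upload runs as a background task after the Stop hook
+    fires, so the test polls bucket storage (up to ~5 s) before asserting
+    that the transcript exists.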
+ """ + api_key: str | None = getenv("OPEN_ROUTER_API_KEY") + if not api_key: + return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test") + + from .config import ChatConfig + + cfg = ChatConfig() + if not cfg.claude_agent_use_resume: + return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test") + + session = await create_chat_session(test_user_id) + session = await upsert_chat_session(session) + + # --- Turn 1: send a message with a unique keyword --- + keyword = "ZEPHYR42" + turn1_msg = ( + f"Please remember this special keyword: {keyword}. " + "Just confirm you've noted it, keep your response brief." + ) + turn1_text = "" + turn1_errors: list[str] = [] + turn1_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn1_msg, + user_id=test_user_id, + ): + if isinstance(chunk, StreamTextDelta): + turn1_text += chunk.delta + elif isinstance(chunk, StreamError): + turn1_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn1_ended = True + + assert turn1_ended, "Turn 1 did not finish" + assert not turn1_errors, f"Turn 1 errors: {turn1_errors}" + assert turn1_text, "Turn 1 produced no text" + + # Wait for background upload task to complete (retry up to 5s) + transcript = None + for _ in range(10): + await asyncio.sleep(0.5) + transcript = await download_transcript(test_user_id, session.session_id) + if transcript: + break + assert transcript, ( + "Transcript was not uploaded to bucket after turn 1 — " + "Stop hook may not have fired or transcript was too small" + ) + logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes") + + # Reload session for turn 2 + session = await get_chat_session(session.session_id, test_user_id) + assert session, "Session not found after turn 1" + + # --- Turn 2: ask model to recall the keyword --- + turn2_msg = "What was the special keyword I asked you to remember?" + turn2_text = "" + turn2_errors: list[str] = [] + turn2_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn2_msg, + user_id=test_user_id, + session=session, + ): + if isinstance(chunk, StreamTextDelta): + turn2_text += chunk.delta + elif isinstance(chunk, StreamError): + turn2_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn2_ended = True + + assert turn2_ended, "Turn 2 did not finish" + assert not turn2_errors, f"Turn 2 errors: {turn2_errors}" + assert turn2_text, "Turn 2 produced no text" + assert keyword in turn2_text, ( + f"Model did not recall keyword '{keyword}' in turn 2. 
" + f"Response: {turn2_text[:200]}" + ) + logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}") diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index abc34b1fc9..671aefc7ba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -814,6 +814,28 @@ async def get_active_task_for_session( if task_user_id and user_id != task_user_id: continue + # Auto-expire stale tasks that exceeded stream_timeout + created_at_str = meta.get("created_at", "") + if created_at_str: + try: + created_at = datetime.fromisoformat(created_at_str) + age_seconds = ( + datetime.now(timezone.utc) - created_at + ).total_seconds() + if age_seconds > config.stream_timeout: + logger.warning( + f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... " + f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)" + ) + await mark_task_completed(task_id, "failed") + continue + except (ValueError, TypeError): + pass + + logger.info( + f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..." + ) + # Get the last message ID from Redis Stream stream_key = _get_task_stream_key(task_id) last_id = "0-0" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index dcbc35ef37..1ab4f720bb 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -9,9 +9,12 @@ from backend.api.features.chat.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool +from .bash_exec import BashExecTool +from .check_operation_status import CheckOperationStatusTool from .create_agent import CreateAgentTool from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool from .find_agent import FindAgentTool from .find_block import FindBlockTool from .find_library_agent import FindLibraryAgentTool @@ -19,6 +22,7 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .web_fetch import WebFetchTool from .workspace_files import ( DeleteWorkspaceFileTool, ListWorkspaceFilesTool, @@ -43,8 +47,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "run_agent": RunAgentTool(), "run_block": RunBlockTool(), "view_agent_output": AgentOutputTool(), + "check_operation_status": CheckOperationStatusTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Web fetch for safe URL retrieval + "web_fetch": WebFetchTool(), + # Sandboxed code execution (bubblewrap) + "bash_exec": BashExecTool(), + # Persistent workspace tools (cloud storage, survives across sessions) + # Feature request tools + "search_feature_requests": SearchFeatureRequestsTool(), + "create_feature_request": CreateFeatureRequestTool(), # Workspace tools for CoPilot file operations "list_workspace_files": ListWorkspaceFilesTool(), "read_workspace_file": ReadWorkspaceFileTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py new file mode 
100644 index 0000000000..da9d8bf3fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py @@ -0,0 +1,131 @@ +"""Bash execution tool — run shell commands in a bubblewrap sandbox. + +Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.). +Safety comes from OS-level isolation (bubblewrap): only system dirs visible +read-only, writable workspace only, clean env, no network. + +Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not +available (e.g. macOS development). +""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + BashExecResponse, + ErrorResponse, + ToolResponseBase, +) +from backend.api.features.chat.tools.sandbox import ( + get_workspace_dir, + has_full_sandbox, + run_sandboxed, +) + +logger = logging.getLogger(__name__) + + +class BashExecTool(BaseTool): + """Execute Bash commands in a bubblewrap sandbox.""" + + @property + def name(self) -> str: + return "bash_exec" + + @property + def description(self) -> str: + if not has_full_sandbox(): + return ( + "Bash execution is DISABLED — bubblewrap sandbox is not " + "available on this platform. Do not call this tool." + ) + return ( + "Execute a Bash command or script in a bubblewrap sandbox. " + "Full Bash scripting is supported (loops, conditionals, pipes, " + "functions, etc.). " + "The sandbox shares the same working directory as the SDK Read/Write " + "tools — files created by either are accessible to both. " + "SECURITY: Only system directories (/usr, /bin, /lib, /etc) are " + "visible read-only, the per-session workspace is the only writable " + "path, environment variables are wiped (no secrets), all network " + "access is blocked at the kernel level, and resource limits are " + "enforced (max 64 processes, 512MB memory, 50MB file size). " + "Application code, configs, and other directories are NOT accessible. " + "To fetch web content, use the web_fetch tool instead. " + "Execution is killed after the timeout (default 30s, max 120s). " + "Returns stdout and stderr. " + "Useful for file manipulation, data processing with Unix tools " + "(grep, awk, sed, jq, etc.), and running shell scripts." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command or script to execute.", + }, + "timeout": { + "type": "integer", + "description": ( + "Max execution time in seconds (default 30, max 120)." 
+ ), + "default": 30, + }, + }, + "required": ["command"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + session_id = session.session_id if session else None + + if not has_full_sandbox(): + return ErrorResponse( + message="bash_exec requires bubblewrap sandbox (Linux only).", + error="sandbox_unavailable", + session_id=session_id, + ) + + command: str = (kwargs.get("command") or "").strip() + timeout: int = kwargs.get("timeout", 30) + + if not command: + return ErrorResponse( + message="No command provided.", + error="empty_command", + session_id=session_id, + ) + + workspace = get_workspace_dir(session_id or "default") + + stdout, stderr, exit_code, timed_out = await run_sandboxed( + command=["bash", "-c", command], + cwd=workspace, + timeout=timeout, + ) + + return BashExecResponse( + message=( + "Execution timed out" + if timed_out + else f"Command executed (exit {exit_code})" + ), + stdout=stdout, + stderr=stderr, + exit_code=exit_code, + timed_out=timed_out, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py new file mode 100644 index 0000000000..b8ec770fd0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py @@ -0,0 +1,127 @@ +"""CheckOperationStatusTool — query the status of a long-running operation.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ResponseType, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class OperationStatusResponse(ToolResponseBase): + """Response for check_operation_status tool.""" + + type: ResponseType = ResponseType.OPERATION_STATUS + task_id: str + operation_id: str + status: str # "running", "completed", "failed" + tool_name: str | None = None + message: str = "" + + +class CheckOperationStatusTool(BaseTool): + """Check the status of a long-running operation (create_agent, edit_agent, etc.). + + The CoPilot uses this tool to report back to the user whether an + operation that was started earlier has completed, failed, or is still + running. + """ + + @property + def name(self) -> str: + return "check_operation_status" + + @property + def description(self) -> str: + return ( + "Check the current status of a long-running operation such as " + "create_agent or edit_agent. Accepts either an operation_id or " + "task_id from a previous operation_started response. " + "Returns the current status: running, completed, or failed." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "operation_id": { + "type": "string", + "description": ( + "The operation_id from an operation_started response." + ), + }, + "task_id": { + "type": "string", + "description": ( + "The task_id from an operation_started response. " + "Used as fallback if operation_id is not provided." 
+ ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + from backend.api.features.chat import stream_registry + + operation_id = (kwargs.get("operation_id") or "").strip() + task_id = (kwargs.get("task_id") or "").strip() + + if not operation_id and not task_id: + return ErrorResponse( + message="Please provide an operation_id or task_id.", + error="missing_parameter", + ) + + task = None + if operation_id: + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None and task_id: + task = await stream_registry.get_task(task_id) + + if task is None: + # Task not in Redis — it may have already expired (TTL). + # Check conversation history for the result instead. + return ErrorResponse( + message=( + "Operation not found — it may have already completed and " + "expired from the status tracker. Check the conversation " + "history for the result." + ), + error="not_found", + ) + + status_messages = { + "running": ( + f"The {task.tool_name or 'operation'} is still running. " + "Please wait for it to complete." + ), + "completed": ( + f"The {task.tool_name or 'operation'} has completed successfully." + ), + "failed": f"The {task.tool_name or 'operation'} has failed.", + } + + return OperationStatusResponse( + task_id=task.task_id, + operation_id=task.operation_id, + status=task.status, + tool_name=task.tool_name, + message=status_messages.get(task.status, f"Status: {task.status}"), + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py new file mode 100644 index 0000000000..95f1eb1fbe --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py @@ -0,0 +1,448 @@ +"""Feature request tools - search and create feature requests via Linear.""" + +import logging +from typing import Any + +from pydantic import SecretStr + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestInfo, + FeatureRequestSearchResponse, + NoResultsResponse, + ToolResponseBase, +) +from backend.blocks.linear._api import LinearClient +from backend.data.model import APIKeyCredentials +from backend.data.user import get_user_email_by_id +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +MAX_SEARCH_RESULTS = 10 + +# GraphQL queries/mutations +SEARCH_ISSUES_QUERY = """ +query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) { + searchIssues(term: $term, filter: $filter, first: $first) { + nodes { + id + identifier + title + description + } + } +} +""" + +CUSTOMER_UPSERT_MUTATION = """ +mutation CustomerUpsert($input: CustomerUpsertInput!) { + customerUpsert(input: $input) { + success + customer { + id + name + externalIds + } + } +} +""" + +ISSUE_CREATE_MUTATION = """ +mutation IssueCreate($input: IssueCreateInput!) { + issueCreate(input: $input) { + success + issue { + id + identifier + title + url + } + } +} +""" + +CUSTOMER_NEED_CREATE_MUTATION = """ +mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) 
{ + customerNeedCreate(input: $input) { + success + need { + id + body + customer { + id + name + } + issue { + id + identifier + title + url + } + } + } +} +""" + + +_settings: Settings | None = None + + +def _get_settings() -> Settings: + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def _get_linear_config() -> tuple[LinearClient, str, str]: + """Return a configured Linear client, project ID, and team ID. + + Raises RuntimeError if any required setting is missing. + """ + secrets = _get_settings().secrets + if not secrets.linear_api_key: + raise RuntimeError("LINEAR_API_KEY is not configured") + if not secrets.linear_feature_request_project_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured") + if not secrets.linear_feature_request_team_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_TEAM_ID is not configured") + + credentials = APIKeyCredentials( + id="system-linear", + provider="linear", + api_key=SecretStr(secrets.linear_api_key), + title="System Linear API Key", + ) + client = LinearClient(credentials=credentials) + return ( + client, + secrets.linear_feature_request_project_id, + secrets.linear_feature_request_team_id, + ) + + +class SearchFeatureRequestsTool(BaseTool): + """Tool for searching existing feature requests in Linear.""" + + @property + def name(self) -> str: + return "search_feature_requests" + + @property + def description(self) -> str: + return ( + "Search existing feature requests to check if a similar request " + "already exists before creating a new one. Returns matching feature " + "requests with their ID, title, and description." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search term to find matching feature requests.", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + query = kwargs.get("query", "").strip() + session_id = session.session_id if session else None + + if not query: + return ErrorResponse( + message="Please provide a search query.", + error="Missing query parameter", + session_id=session_id, + ) + + try: + client, project_id, _team_id = _get_linear_config() + data = await client.query( + SEARCH_ISSUES_QUERY, + { + "term": query, + "filter": { + "project": {"id": {"eq": project_id}}, + }, + "first": MAX_SEARCH_RESULTS, + }, + ) + + nodes = data.get("searchIssues", {}).get("nodes", []) + + if not nodes: + return NoResultsResponse( + message=f"No feature requests found matching '{query}'.", + suggestions=[ + "Try different keywords", + "Use broader search terms", + "You can create a new feature request if none exists", + ], + session_id=session_id, + ) + + results = [ + FeatureRequestInfo( + id=node["id"], + identifier=node["identifier"], + title=node["title"], + description=node.get("description"), + ) + for node in nodes + ] + + return FeatureRequestSearchResponse( + message=f"Found {len(results)} feature request(s) matching '{query}'.", + results=results, + count=len(results), + query=query, + session_id=session_id, + ) + except Exception as e: + logger.exception("Failed to search feature requests") + return ErrorResponse( + message="Failed to search feature requests.", + error=str(e), + session_id=session_id, + ) + + +class CreateFeatureRequestTool(BaseTool): + """Tool for creating feature 
requests (or adding needs to existing ones).""" + + @property + def name(self) -> str: + return "create_feature_request" + + @property + def description(self) -> str: + return ( + "Create a new feature request or add a customer need to an existing one. " + "Always search first with search_feature_requests to avoid duplicates. " + "If a matching request exists, pass its ID as existing_issue_id to add " + "the user's need to it instead of creating a duplicate." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title for the feature request.", + }, + "description": { + "type": "string", + "description": "Detailed description of what the user wants and why.", + }, + "existing_issue_id": { + "type": "string", + "description": ( + "If adding a need to an existing feature request, " + "provide its Linear issue ID (from search results). " + "Omit to create a new feature request." + ), + }, + }, + "required": ["title", "description"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _find_or_create_customer( + self, client: LinearClient, user_id: str, name: str + ) -> dict: + """Find existing customer by user_id or create a new one via upsert. + + Args: + client: Linear API client. + user_id: Stable external ID used to deduplicate customers. + name: Human-readable display name (e.g. the user's email). + """ + data = await client.mutate( + CUSTOMER_UPSERT_MUTATION, + { + "input": { + "name": name, + "externalId": user_id, + }, + }, + ) + result = data.get("customerUpsert", {}) + if not result.get("success"): + raise RuntimeError(f"Failed to upsert customer: {data}") + return result["customer"] + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + title = kwargs.get("title", "").strip() + description = kwargs.get("description", "").strip() + existing_issue_id = kwargs.get("existing_issue_id") + session_id = session.session_id if session else None + + if not title or not description: + return ErrorResponse( + message="Both title and description are required.", + error="Missing required parameters", + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="Authentication required to create feature requests.", + error="Missing user_id", + session_id=session_id, + ) + + try: + client, project_id, team_id = _get_linear_config() + except Exception as e: + logger.exception("Failed to initialize Linear client") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Resolve a human-readable name (email) for the Linear customer record. + # Fall back to user_id if the lookup fails or returns None. 
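+        # Flow from here: (1) customerUpsert keyed by the stable external
+        # user_id, (2) issueCreate only when no existing_issue_id was given,
+        # (3) customerNeedCreate linking the customer to the issue. Each
+        # step returns an ErrorResponse on failure.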
+ try: + customer_display_name = await get_user_email_by_id(user_id) or user_id + except Exception: + customer_display_name = user_id + + # Step 1: Find or create customer for this user + try: + customer = await self._find_or_create_customer( + client, user_id, customer_display_name + ) + customer_id = customer["id"] + customer_name = customer["name"] + except Exception as e: + logger.exception("Failed to upsert customer in Linear") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Step 2: Create or reuse issue + issue_id: str | None = None + issue_identifier: str | None = None + if existing_issue_id: + # Add need to existing issue - we still need the issue details for response + is_new_issue = False + issue_id = existing_issue_id + else: + # Create new issue in the feature requests project + try: + data = await client.mutate( + ISSUE_CREATE_MUTATION, + { + "input": { + "title": title, + "description": description, + "teamId": team_id, + "projectId": project_id, + }, + }, + ) + result = data.get("issueCreate", {}) + if not result.get("success"): + return ErrorResponse( + message="Failed to create feature request issue.", + error=str(data), + session_id=session_id, + ) + issue = result["issue"] + issue_id = issue["id"] + issue_identifier = issue.get("identifier") + except Exception as e: + logger.exception("Failed to create feature request issue") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + is_new_issue = True + + # Step 3: Create customer need on the issue + try: + data = await client.mutate( + CUSTOMER_NEED_CREATE_MUTATION, + { + "input": { + "customerId": customer_id, + "issueId": issue_id, + "body": description, + "priority": 0, + }, + }, + ) + need_result = data.get("customerNeedCreate", {}) + if not need_result.get("success"): + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(data), + details=orphaned, + session_id=session_id, + ) + need = need_result["need"] + issue_info = need["issue"] + except Exception as e: + logger.exception("Failed to create customer need") + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(e), + details=orphaned, + session_id=session_id, + ) + + return FeatureRequestCreatedResponse( + message=( + f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'}: " + f"{issue_info['title']}." 
+ ), + issue_id=issue_info["id"], + issue_identifier=issue_info["identifier"], + issue_title=issue_info["title"], + issue_url=issue_info.get("url", ""), + is_new_issue=is_new_issue, + customer_name=customer_name, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py new file mode 100644 index 0000000000..438725368f --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py @@ -0,0 +1,615 @@ +"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.api.features.chat.tools.feature_requests import ( + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +) +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestSearchResponse, + NoResultsResponse, +) + +from ._test_data import make_session + +_TEST_USER_ID = "test-user-feature-requests" +_TEST_USER_EMAIL = "testuser@example.com" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +_FAKE_PROJECT_ID = "test-project-id" +_FAKE_TEAM_ID = "test-team-id" + + +def _mock_linear_config(*, query_return=None, mutate_return=None): + """Return a patched _get_linear_config that yields a mock LinearClient.""" + client = AsyncMock() + if query_return is not None: + client.query.return_value = query_return + if mutate_return is not None: + client.mutate.return_value = mutate_return + return ( + patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID), + ), + client, + ) + + +def _search_response(nodes: list[dict]) -> dict: + return {"searchIssues": {"nodes": nodes}} + + +def _customer_upsert_response( + customer_id: str = "cust-1", name: str = _TEST_USER_EMAIL, success: bool = True +) -> dict: + return { + "customerUpsert": { + "success": success, + "customer": {"id": customer_id, "name": name, "externalIds": [name]}, + } + } + + +def _issue_create_response( + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "issueCreate": { + "success": success, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + } + } + + +def _need_create_response( + need_id: str = "need-1", + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "customerNeedCreate": { + "success": success, + "need": { + "id": need_id, + "body": "description", + "customer": {"id": "cust-1", "name": _TEST_USER_EMAIL}, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + }, + } + } + + +# =========================================================================== +# SearchFeatureRequestsTool +# =========================================================================== + + +class TestSearchFeatureRequestsTool: + """Tests for SearchFeatureRequestsTool._execute.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_successful_search(self): + session = make_session(user_id=_TEST_USER_ID) + nodes = [ + { + "id": "id-1", + 
"identifier": "FR-1", + "title": "Dark mode", + "description": "Add dark mode support", + }, + { + "id": "id-2", + "identifier": "FR-2", + "title": "Dark theme", + "description": None, + }, + ] + patcher, _ = _mock_linear_config(query_return=_search_response(nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="dark mode" + ) + + assert isinstance(resp, FeatureRequestSearchResponse) + assert resp.count == 2 + assert resp.results[0].id == "id-1" + assert resp.results[1].identifier == "FR-2" + assert resp.query == "dark mode" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_results(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, _ = _mock_linear_config(query_return=_search_response([])) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="nonexistent" + ) + + assert isinstance(resp, NoResultsResponse) + assert "nonexistent" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_empty_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session, query=" ") + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "query" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_api_failure(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.query.side_effect = RuntimeError("Linear API down") + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Linear API down" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_malformed_node_returns_error(self): + """A node missing required keys should be caught by the try/except.""" + session = make_session(user_id=_TEST_USER_ID) + # Node missing 'identifier' key + bad_nodes = [{"id": "id-1", "title": "Missing identifier"}] + patcher, _ = _mock_linear_config(query_return=_search_response(bad_nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + +# =========================================================================== +# CreateFeatureRequestTool +# =========================================================================== + + +class TestCreateFeatureRequestTool: + """Tests 
for CreateFeatureRequestTool._execute.""" + + @pytest.fixture(autouse=True) + def _patch_email_lookup(self): + with patch( + "backend.api.features.chat.tools.feature_requests.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TEST_USER_EMAIL, + ): + yield + + # ---- Happy paths ------------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_create_new_issue(self): + """Full happy path: upsert customer -> create issue -> attach need.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + _need_create_response(), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="New Feature", + description="Please add this", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is True + assert resp.issue_identifier == "FR-1" + assert resp.customer_name == _TEST_USER_EMAIL + assert client.mutate.call_count == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_add_need_to_existing_issue(self): + """When existing_issue_id is provided, skip issue creation.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(issue_id="existing-1", identifier="FR-99"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Existing Feature", + description="Me too", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is False + assert resp.issue_id == "existing-1" + # Only 2 mutations: customer upsert + need create (no issue create) + assert client.mutate.call_count == 2 + + # ---- Validation errors ------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_title(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="", + description="some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_description(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Some title", + description="", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_user_id(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=None, + session=session, + title="Some title", + description="Some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "user_id" in resp.error.lower() + + # ---- Linear client init failure ---------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + 
"backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + # ---- Customer upsert failures ------------------------------------------ + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = RuntimeError("Customer API error") + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Customer API error" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.return_value = _customer_upsert_response(success=False) + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_malformed_response(self): + """Customer dict missing 'id' key should be caught.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + # success=True but customer has no 'id' + client.mutate.return_value = { + "customerUpsert": { + "success": True, + "customer": {"name": _TEST_USER_ID}, + } + } + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Issue creation failures ------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Issue create failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Issue create failed" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert "Failed to create feature request issue" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_malformed_response(self): + """issueCreate success=True but missing 'issue' key.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, 
client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + {"issueCreate": {"success": True}}, # no 'issue' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Customer need attachment failures --------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_new_issue(self): + """Need creation fails after new issue was created -> orphaned issue info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-1", identifier="FR-10"), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Need attach failed" in resp.error + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-1" + assert resp.details["issue_identifier"] == "FR-10" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_existing_issue(self): + """Need creation fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_includes_orphaned_info(self): + """customerNeedCreate returns success=False -> includes orphaned issue.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-2", identifier="FR-20"), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-2" + assert resp.details["issue_identifier"] == "FR-20" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_existing_issue_no_details(self): + """customerNeedCreate fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_malformed_response(self): + 
"""need_result missing 'need' key after success=True.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + {"customerNeedCreate": {"success": True}}, # no 'need' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "issue-1" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index 6a8cfa9bbc..c51317cb62 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -7,7 +7,6 @@ from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase from backend.api.features.chat.tools.models import ( BlockInfoSummary, - BlockInputFieldInfo, BlockListResponse, ErrorResponse, NoResultsResponse, @@ -55,7 +54,8 @@ class FindBlockTool(BaseTool): "Blocks are reusable components that perform specific tasks like " "sending emails, making API calls, processing text, etc. " "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. " - "The response includes each block's id, required_inputs, and input_schema." + "The response includes each block's id, name, and description. " + "Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it." ) @property @@ -124,7 +124,7 @@ class FindBlockTool(BaseTool): session_id=session_id, ) - # Enrich results with full block information + # Enrich results with block information blocks: list[BlockInfoSummary] = [] for result in results: block_id = result["content_id"] @@ -141,65 +141,12 @@ class FindBlockTool(BaseTool): ): continue - # Get input/output schemas - input_schema = {} - output_schema = {} - try: - input_schema = block.input_schema.jsonschema() - except Exception as e: - logger.debug( - "Failed to generate input schema for block %s: %s", - block_id, - e, - ) - try: - output_schema = block.output_schema.jsonschema() - except Exception as e: - logger.debug( - "Failed to generate output schema for block %s: %s", - block_id, - e, - ) - - # Get categories from block instance - categories = [] - if hasattr(block, "categories") and block.categories: - categories = [cat.value for cat in block.categories] - - # Extract required inputs for easier use - required_inputs: list[BlockInputFieldInfo] = [] - if input_schema: - properties = input_schema.get("properties", {}) - required_fields = set(input_schema.get("required", [])) - # Get credential field names to exclude from required inputs - credentials_fields = set( - block.input_schema.get_credentials_fields().keys() - ) - - for field_name, field_schema in properties.items(): - # Skip credential fields - they're handled separately - if field_name in credentials_fields: - continue - - required_inputs.append( - BlockInputFieldInfo( - name=field_name, - type=field_schema.get("type", "string"), - description=field_schema.get("description", ""), - required=field_name in required_fields, - default=field_schema.get("default"), - ) - ) - blocks.append( BlockInfoSummary( id=block_id, name=block.name, description=block.description or 
"", - categories=categories, - input_schema=input_schema, - output_schema=output_schema, - required_inputs=required_inputs, + categories=[c.value for c in block.categories], ) ) @@ -228,8 +175,7 @@ class FindBlockTool(BaseTool): return BlockListResponse( message=( f"Found {len(blocks)} block(s) matching '{query}'. " - "To execute a block, use run_block with the block's 'id' field " - "and provide 'input_data' matching the block's input_schema." + "To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs." ), blocks=blocks, count=len(blocks), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py index d567a89bbe..44606f81c3 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py @@ -18,7 +18,13 @@ _TEST_USER_ID = "test-user-find-block" def make_mock_block( - block_id: str, name: str, block_type: BlockType, disabled: bool = False + block_id: str, + name: str, + block_type: BlockType, + disabled: bool = False, + input_schema: dict | None = None, + output_schema: dict | None = None, + credentials_fields: dict | None = None, ): """Create a mock block for testing.""" mock = MagicMock() @@ -28,10 +34,13 @@ def make_mock_block( mock.block_type = block_type mock.disabled = disabled mock.input_schema = MagicMock() - mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []} - mock.input_schema.get_credentials_fields.return_value = {} + mock.input_schema.jsonschema.return_value = input_schema or { + "properties": {}, + "required": [], + } + mock.input_schema.get_credentials_fields.return_value = credentials_fields or {} mock.output_schema = MagicMock() - mock.output_schema.jsonschema.return_value = {} + mock.output_schema.jsonschema.return_value = output_schema or {} mock.categories = [] return mock @@ -137,3 +146,241 @@ class TestFindBlockFiltering: assert isinstance(response, BlockListResponse) assert len(response.blocks) == 1 assert response.blocks[0].id == "normal-block-id" + + @pytest.mark.asyncio(loop_scope="session") + async def test_response_size_average_chars_per_block(self): + """Measure average chars per block in the serialized response.""" + session = make_session(user_id=_TEST_USER_ID) + + # Realistic block definitions modeled after real blocks + block_defs = [ + { + "id": "http-block-id", + "name": "Send Web Request", + "input_schema": { + "properties": { + "url": { + "type": "string", + "description": "The URL to send the request to", + }, + "method": { + "type": "string", + "description": "The HTTP method to use", + }, + "headers": { + "type": "object", + "description": "Headers to include in the request", + }, + "json_format": { + "type": "boolean", + "description": "If true, send the body as JSON", + }, + "body": { + "type": "object", + "description": "Form/JSON body payload", + }, + "credentials": { + "type": "object", + "description": "HTTP credentials", + }, + }, + "required": ["url", "method"], + }, + "output_schema": { + "properties": { + "response": { + "type": "object", + "description": "The response from the server", + }, + "client_error": { + "type": "object", + "description": "Errors on 4xx status codes", + }, + "server_error": { + "type": "object", + "description": "Errors on 5xx status codes", + }, + "error": { + "type": "string", + "description": "Errors for all other 
exceptions", + }, + }, + }, + "credentials_fields": {"credentials": True}, + }, + { + "id": "email-block-id", + "name": "Send Email", + "input_schema": { + "properties": { + "to_email": { + "type": "string", + "description": "Recipient email address", + }, + "subject": { + "type": "string", + "description": "Subject of the email", + }, + "body": { + "type": "string", + "description": "Body of the email", + }, + "config": { + "type": "object", + "description": "SMTP Config", + }, + "credentials": { + "type": "object", + "description": "SMTP credentials", + }, + }, + "required": ["to_email", "subject", "body", "credentials"], + }, + "output_schema": { + "properties": { + "status": { + "type": "string", + "description": "Status of the email sending operation", + }, + "error": { + "type": "string", + "description": "Error message if sending failed", + }, + }, + }, + "credentials_fields": {"credentials": True}, + }, + { + "id": "claude-code-block-id", + "name": "Claude Code", + "input_schema": { + "properties": { + "e2b_credentials": { + "type": "object", + "description": "API key for E2B platform", + }, + "anthropic_credentials": { + "type": "object", + "description": "API key for Anthropic", + }, + "prompt": { + "type": "string", + "description": "Task or instruction for Claude Code", + }, + "timeout": { + "type": "integer", + "description": "Sandbox timeout in seconds", + }, + "setup_commands": { + "type": "array", + "description": "Shell commands to run before execution", + }, + "working_directory": { + "type": "string", + "description": "Working directory for Claude Code", + }, + "session_id": { + "type": "string", + "description": "Session ID to resume a conversation", + }, + "sandbox_id": { + "type": "string", + "description": "Sandbox ID to reconnect to", + }, + "conversation_history": { + "type": "string", + "description": "Previous conversation history", + }, + "dispose_sandbox": { + "type": "boolean", + "description": "Whether to dispose sandbox after execution", + }, + }, + "required": [ + "e2b_credentials", + "anthropic_credentials", + "prompt", + ], + }, + "output_schema": { + "properties": { + "response": { + "type": "string", + "description": "Output from Claude Code execution", + }, + "files": { + "type": "array", + "description": "Files created/modified by Claude Code", + }, + "conversation_history": { + "type": "string", + "description": "Full conversation history", + }, + "session_id": { + "type": "string", + "description": "Session ID for this conversation", + }, + "sandbox_id": { + "type": "string", + "description": "ID of the sandbox instance", + }, + "error": { + "type": "string", + "description": "Error message if execution failed", + }, + }, + }, + "credentials_fields": { + "e2b_credentials": True, + "anthropic_credentials": True, + }, + }, + ] + + search_results = [ + {"content_id": d["id"], "score": 0.9 - i * 0.1} + for i, d in enumerate(block_defs) + ] + mock_blocks = { + d["id"]: make_mock_block( + block_id=d["id"], + name=d["name"], + block_type=BlockType.STANDARD, + input_schema=d["input_schema"], + output_schema=d["output_schema"], + credentials_fields=d["credentials_fields"], + ) + for d in block_defs + } + + with patch( + "backend.api.features.chat.tools.find_block.unified_hybrid_search", + new_callable=AsyncMock, + return_value=(search_results, len(search_results)), + ), patch( + "backend.api.features.chat.tools.find_block.get_block", + side_effect=lambda bid: mock_blocks.get(bid), + ): + tool = FindBlockTool() + response = await tool._execute( + 
user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(response, BlockListResponse) + assert response.count == len(block_defs) + + total_chars = len(response.model_dump_json()) + avg_chars = total_chars // response.count + + # Print for visibility in test output + print(f"\nTotal response size: {total_chars} chars") + print(f"Number of blocks: {response.count}") + print(f"Average chars per block: {avg_chars}") + + # The old response was ~90K for 10 blocks (~9K per block). + # Previous optimization reduced it to ~1.5K per block (no raw JSON schemas). + # Now with only id/name/description, we expect ~300 chars per block. + assert avg_chars < 500, ( + f"Average chars per block ({avg_chars}) exceeds 500. " + f"Total response: {total_chars} chars for {response.count} blocks." + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 69c8c6c684..b32f6ca2ce 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -25,6 +25,7 @@ class ResponseType(str, Enum): AGENT_SAVED = "agent_saved" CLARIFICATION_NEEDED = "clarification_needed" BLOCK_LIST = "block_list" + BLOCK_DETAILS = "block_details" BLOCK_OUTPUT = "block_output" DOC_SEARCH_RESULTS = "doc_search_results" DOC_PAGE = "doc_page" @@ -40,6 +41,15 @@ class ResponseType(str, Enum): OPERATION_IN_PROGRESS = "operation_in_progress" # Input validation INPUT_VALIDATION_ERROR = "input_validation_error" + # Web fetch + WEB_FETCH = "web_fetch" + # Code execution + BASH_EXEC = "bash_exec" + # Operation status check + OPERATION_STATUS = "operation_status" + # Feature request types + FEATURE_REQUEST_SEARCH = "feature_request_search" + FEATURE_REQUEST_CREATED = "feature_request_created" # Base response model @@ -335,11 +345,17 @@ class BlockInfoSummary(BaseModel): name: str description: str categories: list[str] - input_schema: dict[str, Any] - output_schema: dict[str, Any] + input_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block inputs", + ) + output_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block outputs", + ) required_inputs: list[BlockInputFieldInfo] = Field( default_factory=list, - description="List of required input fields for this block", + description="List of input fields for this block", ) @@ -352,10 +368,29 @@ class BlockListResponse(ToolResponseBase): query: str usage_hint: str = Field( default="To execute a block, call run_block with block_id set to the block's " - "'id' field and input_data containing the required fields from input_schema." + "'id' field and input_data containing the fields listed in required_inputs." 
) +class BlockDetails(BaseModel): + """Detailed block information.""" + + id: str + name: str + description: str + inputs: dict[str, Any] = {} + outputs: dict[str, Any] = {} + credentials: list[CredentialsMetaInput] = [] + + +class BlockDetailsResponse(ToolResponseBase): + """Response for block details (first run_block attempt).""" + + type: ResponseType = ResponseType.BLOCK_DETAILS + block: BlockDetails + user_authenticated: bool = False + + class BlockOutputResponse(ToolResponseBase): """Response for run_block tool.""" @@ -421,3 +456,55 @@ class AsyncProcessingResponse(ToolResponseBase): status: str = "accepted" # Must be "accepted" for detection operation_id: str | None = None task_id: str | None = None + + +class WebFetchResponse(ToolResponseBase): + """Response for web_fetch tool.""" + + type: ResponseType = ResponseType.WEB_FETCH + url: str + status_code: int + content_type: str + content: str + truncated: bool = False + + +class BashExecResponse(ToolResponseBase): + """Response for bash_exec tool.""" + + type: ResponseType = ResponseType.BASH_EXEC + stdout: str + stderr: str + exit_code: int + timed_out: bool = False + + +# Feature request models +class FeatureRequestInfo(BaseModel): + """Information about a feature request issue.""" + + id: str + identifier: str + title: str + description: str | None = None + + +class FeatureRequestSearchResponse(ToolResponseBase): + """Response for search_feature_requests tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH + results: list[FeatureRequestInfo] + count: int + query: str + + +class FeatureRequestCreatedResponse(ToolResponseBase): + """Response for create_feature_request tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED + issue_id: str + issue_identifier: str + issue_title: str + issue_url: str + is_new_issue: bool # False if added to existing + customer_name: str diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 8c29820f8e..a55478326a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -23,8 +23,11 @@ from backend.util.exceptions import BlockError from .base import BaseTool from .helpers import get_inputs_from_schema from .models import ( + BlockDetails, + BlockDetailsResponse, BlockOutputResponse, ErrorResponse, + InputValidationErrorResponse, SetupInfo, SetupRequirementsResponse, ToolResponseBase, @@ -51,8 +54,8 @@ class RunBlockTool(BaseTool): "Execute a specific block with the provided input data. " "IMPORTANT: You MUST call find_block first to get the block's 'id' - " "do NOT guess or make up block IDs. " - "Use the 'id' from find_block results and provide input_data " - "matching the block's required_inputs." + "On first attempt (without input_data), returns detailed schema showing " + "required inputs and outputs. Then call again with proper input_data to execute." ) @property @@ -67,11 +70,19 @@ class RunBlockTool(BaseTool): "NEVER guess this - always get it from find_block first." ), }, + "block_name": { + "type": "string", + "description": ( + "The block's human-readable name from find_block results. " + "Used for display purposes in the UI." + ), + }, "input_data": { "type": "object", "description": ( - "Input values for the block. Use the 'required_inputs' field " - "from find_block to see what fields are needed." + "Input values for the block. 
" + "First call with empty {} to see the block's schema, " + "then call again with proper values to execute." ), }, }, @@ -156,6 +167,34 @@ class RunBlockTool(BaseTool): await self._resolve_block_credentials(user_id, block, input_data) ) + # Get block schemas for details/validation + try: + input_schema: dict[str, Any] = block.input_schema.jsonschema() + except Exception as e: + logger.warning( + "Failed to generate input schema for block %s: %s", + block_id, + e, + ) + return ErrorResponse( + message=f"Block '{block.name}' has an invalid input schema", + error=str(e), + session_id=session_id, + ) + try: + output_schema: dict[str, Any] = block.output_schema.jsonschema() + except Exception as e: + logger.warning( + "Failed to generate output schema for block %s: %s", + block_id, + e, + ) + return ErrorResponse( + message=f"Block '{block.name}' has an invalid output schema", + error=str(e), + session_id=session_id, + ) + if missing_credentials: # Return setup requirements response with missing credentials credentials_fields_info = block.input_schema.get_credentials_fields_info() @@ -188,6 +227,53 @@ class RunBlockTool(BaseTool): graph_version=None, ) + # Check if this is a first attempt (required inputs missing) + # Return block details so user can see what inputs are needed + credentials_fields = set(block.input_schema.get_credentials_fields().keys()) + required_keys = set(input_schema.get("required", [])) + required_non_credential_keys = required_keys - credentials_fields + provided_input_keys = set(input_data.keys()) - credentials_fields + + # Check for unknown input fields + valid_fields = ( + set(input_schema.get("properties", {}).keys()) - credentials_fields + ) + unrecognized_fields = provided_input_keys - valid_fields + if unrecognized_fields: + return InputValidationErrorResponse( + message=( + f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. " + f"Block was not executed. Please use the correct field names from the schema." + ), + session_id=session_id, + unrecognized_fields=sorted(unrecognized_fields), + inputs=input_schema, + ) + + # Show details when not all required non-credential inputs are provided + if not (required_non_credential_keys <= provided_input_keys): + # Get credentials info for the response + credentials_meta = [] + for field_name, cred_meta in matched_credentials.items(): + credentials_meta.append(cred_meta) + + return BlockDetailsResponse( + message=( + f"Block '{block.name}' details. " + "Provide input_data matching the inputs schema to execute the block." 
+ ), + session_id=session_id, + block=BlockDetails( + id=block_id, + name=block.name, + description=block.description or "", + inputs=input_schema, + outputs=output_schema, + credentials=credentials_meta, + ), + user_authenticated=True, + ) + try: # Get or create user's workspace for CoPilot file operations workspace = await get_or_create_workspace(user_id) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py index aadc161155..55efc38479 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py @@ -1,10 +1,15 @@ -"""Tests for block execution guards in RunBlockTool.""" +"""Tests for block execution guards and input validation in RunBlockTool.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.models import ErrorResponse +from backend.api.features.chat.tools.models import ( + BlockDetailsResponse, + BlockOutputResponse, + ErrorResponse, + InputValidationErrorResponse, +) from backend.api.features.chat.tools.run_block import RunBlockTool from backend.blocks._base import BlockType @@ -28,6 +33,39 @@ def make_mock_block( return mock +def make_mock_block_with_schema( + block_id: str, + name: str, + input_properties: dict, + required_fields: list[str], + output_properties: dict | None = None, +): + """Create a mock block with a defined input/output schema for validation tests.""" + mock = MagicMock() + mock.id = block_id + mock.name = name + mock.block_type = BlockType.STANDARD + mock.disabled = False + mock.description = f"Test block: {name}" + + input_schema = { + "properties": input_properties, + "required": required_fields, + } + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = input_schema + mock.input_schema.get_credentials_fields_info.return_value = {} + mock.input_schema.get_credentials_fields.return_value = {} + + output_schema = { + "properties": output_properties or {"result": {"type": "string"}}, + } + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = output_schema + + return mock + + class TestRunBlockFiltering: """Tests for block execution guards in RunBlockTool.""" @@ -104,3 +142,221 @@ class TestRunBlockFiltering: # (may be other errors like missing credentials, but not the exclusion guard) if isinstance(response, ErrorResponse): assert "cannot be run directly in CoPilot" not in response.message + + +class TestRunBlockInputValidation: + """Tests for input field validation in RunBlockTool. + + run_block rejects unknown input field names with InputValidationErrorResponse, + preventing silent failures where incorrect keys would be ignored and the block + would execute with default values instead of the caller's intended values. + """ + + @pytest.mark.asyncio(loop_scope="session") + async def test_unknown_input_fields_are_rejected(self): + """run_block rejects unknown input fields instead of silently ignoring them. + + Scenario: The AI Text Generator block has a field called 'model' (for LLM model + selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model' + instead. The block should reject the request and return the valid schema. 
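+
+        A sketch of the response shape this test asserts (field values shown
+        are illustrative, not matched verbatim):
+
+            InputValidationErrorResponse(
+                message="Unknown input field(s) provided: LLM_Model. "
+                        "Block was not executed. ...",
+                unrecognized_fields=["LLM_Model"],
+                inputs={...},  # the block's input schema, so the caller can retry
+            )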
+ """ + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string", "description": "The prompt to send"}, + "model": { + "type": "string", + "description": "The LLM model to use", + "default": "gpt-4o-mini", + }, + "sys_prompt": { + "type": "string", + "description": "System prompt", + "default": "", + }, + }, + required_fields=["prompt"], + output_properties={"response": {"type": "string"}}, + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key) + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Write a haiku about coding", + "LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model' + }, + ) + + assert isinstance(response, InputValidationErrorResponse) + assert "LLM_Model" in response.unrecognized_fields + assert "Block was not executed" in response.message + assert "inputs" in response.model_dump() # valid schema included + + @pytest.mark.asyncio(loop_scope="session") + async def test_multiple_wrong_keys_are_all_reported(self): + """All unrecognized field names are reported in a single error response.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + "sys_prompt": {"type": "string", "default": ""}, + "retry": {"type": "integer", "default": 3}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Hello", # correct + "llm_model": "claude-opus-4-6", # WRONG - should be 'model' + "system_prompt": "Be helpful", # WRONG - should be 'sys_prompt' + "retries": 5, # WRONG - should be 'retry' + }, + ) + + assert isinstance(response, InputValidationErrorResponse) + assert set(response.unrecognized_fields) == { + "llm_model", + "system_prompt", + "retries", + } + assert "Block was not executed" in response.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_unknown_fields_rejected_even_with_missing_required(self): + """Unknown fields are caught before the missing-required-fields check.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # 'prompt' is missing AND 'LLM_Model' is an unknown field + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing + }, + ) + + # Unknown fields are caught first + assert isinstance(response, InputValidationErrorResponse) + assert "LLM_Model" in response.unrecognized_fields + + 
@pytest.mark.asyncio(loop_scope="session") + async def test_correct_inputs_still_execute(self): + """Correct input field names pass validation and the block executes.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + async def mock_execute(input_data, **kwargs): + yield "response", "Generated text" + + mock_block.execute = mock_execute + + with ( + patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ), + patch( + "backend.api.features.chat.tools.run_block.get_or_create_workspace", + new_callable=AsyncMock, + return_value=MagicMock(id="test-workspace-id"), + ), + ): + tool = RunBlockTool() + + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Write a haiku", + "model": "gpt-4o-mini", # correct field name + }, + ) + + assert isinstance(response, BlockOutputResponse) + assert response.success is True + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_required_fields_returns_details(self): + """Missing required fields returns BlockDetailsResponse with schema.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # Only provide valid optional field, missing required 'prompt' + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "model": "gpt-4o-mini", # valid but optional + }, + ) + + assert isinstance(response, BlockDetailsResponse) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py new file mode 100644 index 0000000000..beb326f909 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py @@ -0,0 +1,265 @@ +"""Sandbox execution utilities for code execution tools. + +Provides filesystem + network isolated command execution using **bubblewrap** +(``bwrap``): whitelist-only filesystem (only system dirs visible read-only), +writable workspace only, clean environment, network blocked. + +Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox` +and refuse to run if bubblewrap is not available. +""" + +import asyncio +import logging +import os +import platform +import shutil + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 30 +_MAX_TIMEOUT = 120 + + +# --------------------------------------------------------------------------- +# Sandbox capability detection (cached at first call) +# --------------------------------------------------------------------------- + +_BWRAP_AVAILABLE: bool | None = None + + +def has_full_sandbox() -> bool: + """Return True if bubblewrap is available (filesystem + network isolation). + + On non-Linux platforms (macOS), always returns False. 
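+
+    Hypothetical caller-side sketch (``session_id`` is assumed to be in
+    scope; the other names are this module's API)::
+
+        if has_full_sandbox():
+            stdout, stderr, exit_code, timed_out = await run_sandboxed(
+                ["python3", "-c", "print('hello')"],
+                cwd=get_workspace_dir(session_id),
+            )
+        else:
+            ...  # refuse to run rather than execute unsandboxed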
+ """ + global _BWRAP_AVAILABLE + if _BWRAP_AVAILABLE is None: + _BWRAP_AVAILABLE = ( + platform.system() == "Linux" and shutil.which("bwrap") is not None + ) + return _BWRAP_AVAILABLE + + +WORKSPACE_PREFIX = "/tmp/copilot-" + + +def make_session_path(session_id: str) -> str: + """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`. + + Shared by both the SDK working-directory setup and the sandbox tools so + they always resolve to the same directory for a given session. + + Steps: + 1. Strip all characters except ``[A-Za-z0-9-]``. + 2. Construct ``/tmp/copilot-``. + 3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised + sanitizer) to prevent path traversal. + + Raises: + ValueError: If the resulting path escapes the prefix. + """ + import re + + safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) + if not safe_id: + safe_id = "default" + path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}") + if not path.startswith(WORKSPACE_PREFIX): + raise ValueError(f"Session path escaped prefix: {path}") + return path + + +def get_workspace_dir(session_id: str) -> str: + """Get or create the workspace directory for a session. + + Uses :func:`make_session_path` — the same path the SDK uses — so that + bash_exec shares the workspace with the SDK file tools. + """ + workspace = make_session_path(session_id) + os.makedirs(workspace, exist_ok=True) + return workspace + + +# --------------------------------------------------------------------------- +# Bubblewrap command builder +# --------------------------------------------------------------------------- + +# System directories mounted read-only inside the sandbox. +# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible. +_SYSTEM_RO_BINDS = [ + "/usr", # binaries, libraries, Python interpreter + "/etc", # system config: ld.so, locale, passwd, alternatives +] + +# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems. +# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind +# can't create a symlink target, so we detect and use --symlink instead. +# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2. +_COMPAT_PATHS = [ + ("/bin", "usr/bin"), # -> /usr/bin on Debian 13 + ("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13 + ("/lib", "usr/lib"), # -> /usr/lib on Debian 13 + ("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter +] + +# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse. +# Applied via ulimit inside the sandbox before exec'ing the user command. +_RESOURCE_LIMITS = ( + "ulimit -u 64" # max 64 processes (prevents fork bombs) + " -v 524288" # 512 MB virtual memory + " -f 51200" # 50 MB max file size (1024-byte blocks) + " -n 256" # 256 open file descriptors + " 2>/dev/null" +) + + +def _build_bwrap_command( + command: list[str], cwd: str, env: dict[str, str] +) -> list[str]: + """Build a bubblewrap command with strict filesystem + network isolation. + + Security model: + - **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``, + ``/bin``, ``/lib``) are mounted read-only. Application code (``/app``), + home directories, ``/var``, ``/opt``, etc. are NOT accessible at all. + - **Writable workspace only**: the per-session workspace is the sole + writable path. + - **Clean environment**: ``--clearenv`` wipes all inherited env vars. + Only the explicitly-passed safe env vars are set inside the sandbox. 
+    - **Network isolation**: ``--unshare-net`` blocks all network access.
+    - **Resource limits**: ulimit caps on processes (64), memory (512MB),
+      file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
+    - **New session**: prevents terminal control escape.
+    - **Die with parent**: prevents orphaned sandbox processes.
+    """
+    cmd = [
+        "bwrap",
+        # Create a new user namespace so bwrap can set up sandboxing
+        # inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
+        "--unshare-user",
+        # Wipe all inherited environment variables (API keys, secrets, etc.)
+        "--clearenv",
+    ]
+
+    # Set only the safe env vars inside the sandbox
+    for key, value in env.items():
+        cmd.extend(["--setenv", key, value])
+
+    # System directories: read-only
+    for path in _SYSTEM_RO_BINDS:
+        cmd.extend(["--ro-bind", path, path])
+
+    # Compat paths: use --symlink when host path is a symlink (Debian 13),
+    # --ro-bind when it's a real directory (older distros).
+    for path, symlink_target in _COMPAT_PATHS:
+        if os.path.islink(path):
+            cmd.extend(["--symlink", symlink_target, path])
+        elif os.path.exists(path):
+            cmd.extend(["--ro-bind", path, path])
+
+    # Wrap the user command with resource limits:
+    #   sh -c 'ulimit ...; exec "$@"' -- <command>
+    # `exec "$@"` replaces the shell so there's no extra process overhead,
+    # and properly handles arguments with spaces.
+    limited_command = [
+        "sh",
+        "-c",
+        f'{_RESOURCE_LIMITS}; exec "$@"',
+        "--",
+        *command,
+    ]
+
+    cmd.extend(
+        [
+            # Fresh virtual filesystems
+            "--dev",
+            "/dev",
+            "--proc",
+            "/proc",
+            "--tmpfs",
+            "/tmp",
+            # Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
+            # (workspace lives under /tmp/copilot-<session-id>)
+            "--bind",
+            cwd,
+            cwd,
+            # Isolation
+            "--unshare-net",
+            "--die-with-parent",
+            "--new-session",
+            "--chdir",
+            cwd,
+            "--",
+            *limited_command,
+        ]
+    )
+
+    return cmd
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+async def run_sandboxed(
+    command: list[str],
+    cwd: str,
+    timeout: int = _DEFAULT_TIMEOUT,
+    env: dict[str, str] | None = None,
+) -> tuple[str, str, int, bool]:
+    """Run a command inside a bubblewrap sandbox.
+
+    Callers **must** check :func:`has_full_sandbox` before calling this
+    function. If bubblewrap is not available, this function raises
+    :class:`RuntimeError` rather than running unsandboxed.
+
+    Returns:
+        (stdout, stderr, exit_code, timed_out)
+    """
+    if not has_full_sandbox():
+        raise RuntimeError(
+            "run_sandboxed() requires bubblewrap but bwrap is not available. "
+            "Callers must check has_full_sandbox() before calling this function."
+ ) + + timeout = min(max(timeout, 1), _MAX_TIMEOUT) + + safe_env = { + "PATH": "/usr/local/bin:/usr/bin:/bin", + "HOME": cwd, + "TMPDIR": cwd, + "LANG": "en_US.UTF-8", + "PYTHONDONTWRITEBYTECODE": "1", + "PYTHONIOENCODING": "utf-8", + } + if env: + safe_env.update(env) + + full_command = _build_bwrap_command(command, cwd, safe_env) + + try: + proc = await asyncio.create_subprocess_exec( + *full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=safe_env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + return stdout, stderr, proc.returncode or 0, False + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + return "", f"Execution timed out after {timeout}s", -1, True + + except RuntimeError: + raise + except Exception as e: + return "", f"Sandbox error: {e}", -1, False diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py b/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py new file mode 100644 index 0000000000..fbab0b723d --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py @@ -0,0 +1,153 @@ +"""Tests for BlockDetailsResponse in RunBlockTool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.api.features.chat.tools.models import BlockDetailsResponse +from backend.api.features.chat.tools.run_block import RunBlockTool +from backend.blocks._base import BlockType +from backend.data.model import CredentialsMetaInput +from backend.integrations.providers import ProviderName + +from ._test_data import make_session + +_TEST_USER_ID = "test-user-run-block-details" + + +def make_mock_block_with_inputs( + block_id: str, name: str, description: str = "Test description" +): + """Create a mock block with input/output schemas for testing.""" + mock = MagicMock() + mock.id = block_id + mock.name = name + mock.description = description + mock.block_type = BlockType.STANDARD + mock.disabled = False + + # Input schema with non-credential fields + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = { + "properties": { + "url": {"type": "string", "description": "URL to fetch"}, + "method": {"type": "string", "description": "HTTP method"}, + }, + "required": ["url"], + } + mock.input_schema.get_credentials_fields.return_value = {} + mock.input_schema.get_credentials_fields_info.return_value = {} + + # Output schema + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = { + "properties": { + "response": {"type": "object", "description": "HTTP response"}, + "error": {"type": "string", "description": "Error message"}, + } + } + + return mock + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_block_returns_details_when_no_input_provided(): + """When run_block is called without input_data, it should return BlockDetailsResponse.""" + session = make_session(user_id=_TEST_USER_ID) + + # Create a block with inputs + http_block = make_mock_block_with_inputs( + "http-block-id", "HTTP Request", "Send HTTP requests" + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=http_block, + ): + # Mock credentials check to return no missing credentials + with patch.object( + RunBlockTool, + "_resolve_block_credentials", + 
new_callable=AsyncMock, + return_value=({}, []), # (matched_credentials, missing_credentials) + ): + tool = RunBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="http-block-id", + input_data={}, # Empty input data + ) + + # Should return BlockDetailsResponse showing the schema + assert isinstance(response, BlockDetailsResponse) + assert response.block.id == "http-block-id" + assert response.block.name == "HTTP Request" + assert response.block.description == "Send HTTP requests" + assert "url" in response.block.inputs["properties"] + assert "method" in response.block.inputs["properties"] + assert "response" in response.block.outputs["properties"] + assert response.user_authenticated is True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_block_returns_details_when_only_credentials_provided(): + """When only credentials are provided (no actual input), should return details.""" + session = make_session(user_id=_TEST_USER_ID) + + # Create a block with both credential and non-credential inputs + mock = MagicMock() + mock.id = "api-block-id" + mock.name = "API Call" + mock.description = "Make API calls" + mock.block_type = BlockType.STANDARD + mock.disabled = False + + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = { + "properties": { + "credentials": {"type": "object", "description": "API credentials"}, + "endpoint": {"type": "string", "description": "API endpoint"}, + }, + "required": ["credentials", "endpoint"], + } + mock.input_schema.get_credentials_fields.return_value = {"credentials": True} + mock.input_schema.get_credentials_fields_info.return_value = {} + + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = { + "properties": {"result": {"type": "object"}} + } + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock, + ): + with patch.object( + RunBlockTool, + "_resolve_block_credentials", + new_callable=AsyncMock, + return_value=( + { + "credentials": CredentialsMetaInput( + id="cred-id", + provider=ProviderName("test_provider"), + type="api_key", + title="Test Credential", + ) + }, + [], + ), + ): + tool = RunBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="api-block-id", + input_data={"credentials": {"some": "cred"}}, # Only credential + ) + + # Should return details because no non-credential inputs provided + assert isinstance(response, BlockDetailsResponse) + assert response.block.id == "api-block-id" + assert response.block.name == "API Call" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index 80a842bf36..3b2168d09e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -15,6 +15,7 @@ from backend.data.model import ( OAuth2Credentials, ) from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName from backend.util.exceptions import NotFoundError logger = logging.getLogger(__name__) @@ -359,7 +360,7 @@ async def match_user_credentials_to_graph( _, _, ) in aggregated_creds.items(): - # Find first matching credential by provider, type, and scopes + # Find first matching credential by provider, type, scopes, and host/URL matching_cred = next( ( cred @@ -374,6 +375,10 @@ async def 
match_user_credentials_to_graph( cred.type != "host_scoped" or _credential_is_for_host(cred, credential_requirements) ) + and ( + cred.provider != ProviderName.MCP + or _credential_is_for_mcp_server(cred, credential_requirements) + ) ), None, ) @@ -444,6 +449,22 @@ def _credential_is_for_host( return credential.matches_url(list(requirements.discriminator_values)[0]) +def _credential_is_for_mcp_server( + credential: Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an MCP OAuth credential matches the required server URL.""" + if not requirements.discriminator_values: + return True + + server_url = ( + credential.metadata.get("mcp_server_url") + if isinstance(credential, OAuth2Credentials) + else None + ) + return server_url in requirements.discriminator_values if server_url else False + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py new file mode 100644 index 0000000000..fed7cc11fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py @@ -0,0 +1,151 @@ +"""Web fetch tool — safely retrieve public web page content.""" + +import logging +from typing import Any + +import aiohttp +import html2text + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ToolResponseBase, + WebFetchResponse, +) +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + +# Limits +_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap +_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15) + +# Content types we'll read as text +_TEXT_CONTENT_TYPES = { + "text/html", + "text/plain", + "text/xml", + "text/csv", + "text/markdown", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/rss+xml", + "application/atom+xml", +} + + +def _is_text_content(content_type: str) -> bool: + base = content_type.split(";")[0].strip().lower() + return base in _TEXT_CONTENT_TYPES or base.startswith("text/") + + +def _html_to_text(html: str) -> str: + h = html2text.HTML2Text() + h.ignore_links = False + h.ignore_images = True + h.body_width = 0 + return h.handle(html) + + +class WebFetchTool(BaseTool): + """Safely fetch content from a public URL using SSRF-protected HTTP.""" + + @property + def name(self) -> str: + return "web_fetch" + + @property + def description(self) -> str: + return ( + "Fetch the content of a public web page by URL. " + "Returns readable text extracted from HTML by default. " + "Useful for reading documentation, articles, and API responses. " + "Only supports HTTP/HTTPS GET requests to public URLs " + "(private/internal network addresses are blocked)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The public HTTP/HTTPS URL to fetch.", + }, + "extract_text": { + "type": "boolean", + "description": ( + "If true (default), extract readable text from HTML. " + "If false, return raw content." 
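+                        # Illustrative note: when extract_text is true the body
+                        # is run through _html_to_text() above -- html2text with
+                        # links kept, images ignored, body_width=0 (no rewrap).
+                        # E.g. "<h1>Title</h1><p>Hi</p>" yields roughly
+                        # "# Title\n\nHi\n" (exact output depends on the
+                        # html2text version).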
+                    ),
+                    "default": True,
+                },
+            },
+            "required": ["url"],
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return False
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs: Any,
+    ) -> ToolResponseBase:
+        url: str = (kwargs.get("url") or "").strip()
+        extract_text: bool = kwargs.get("extract_text", True)
+        session_id = session.session_id if session else None
+
+        if not url:
+            return ErrorResponse(
+                message="Please provide a URL to fetch.",
+                error="missing_url",
+                session_id=session_id,
+            )
+
+        try:
+            client = Requests(raise_for_status=False, retry_max_attempts=1)
+            response = await client.get(url, timeout=_REQUEST_TIMEOUT)
+        except ValueError as e:
+            # validate_url raises ValueError for SSRF / blocked IPs
+            return ErrorResponse(
+                message=f"URL blocked: {e}",
+                error="url_blocked",
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.warning(f"[web_fetch] Request failed for {url}: {e}")
+            return ErrorResponse(
+                message=f"Failed to fetch URL: {e}",
+                error="fetch_failed",
+                session_id=session_id,
+            )
+
+        content_type = response.headers.get("content-type", "")
+        if not _is_text_content(content_type):
+            return ErrorResponse(
+                message=f"Non-text content type: {content_type.split(';')[0]}",
+                error="unsupported_content_type",
+                session_id=session_id,
+            )
+
+        raw = response.content[:_MAX_CONTENT_BYTES]
+        text = raw.decode("utf-8", errors="replace")
+
+        if extract_text and "html" in content_type.lower():
+            text = _html_to_text(text)
+
+        return WebFetchResponse(
+            message=f"Fetched {url}",
+            url=response.url,
+            status_code=response.status,
+            content_type=content_type.split(";")[0].strip(),
+            content=text,
+            # The download is capped at _MAX_CONTENT_BYTES above, so report
+            # whether the cap actually trimmed the body.
+            truncated=len(response.content) > _MAX_CONTENT_BYTES,
+            session_id=session_id,
+        )
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py
index 03532c8fee..f37d2c80e0 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py
@@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "List files in the user's workspace. "
+            "List files in the user's persistent workspace (cloud storage). "
+            "These files survive across sessions. "
+            "For ephemeral session files, use the SDK Read/Glob tools instead. "
             "Returns file names, paths, sizes, and metadata. "
             "Optionally filter by path prefix."
         )
@@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Read a file from the user's workspace. "
+            "Read a file from the user's persistent workspace (cloud storage). "
+            "These files survive across sessions. "
+            "For ephemeral session files, use the SDK Read tool instead. "
             "Specify either file_id or path to identify the file. "
             "For small text files, returns content directly. "
             "For large or binary files, returns metadata and a download URL. "
@@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Write or create a file in the user's workspace. "
+            "Write or create a file in the user's persistent workspace (cloud storage). "
+            "These files survive across sessions. "
+            "For ephemeral session files, use the SDK Write tool instead. "
             "Provide the content as a base64-encoded string. "
            f"Maximum file size is {Config().max_file_size_mb}MB. "
             "Files are saved to the current session's folder by default. 
" @@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Delete a file from the user's workspace. " + "Delete a file from the user's persistent workspace (cloud storage). " "Specify either file_id or path to identify the file. " "Paths are scoped to the current session by default. " "Use /sessions//... for cross-session access." diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index 00500dc8a8..4eacf83e71 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -1,7 +1,7 @@ import asyncio import logging from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Annotated, List, Literal +from typing import TYPE_CHECKING, Annotated, Any, List, Literal from autogpt_libs.auth import get_user_id from fastapi import ( @@ -14,7 +14,7 @@ from fastapi import ( Security, status, ) -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, model_validator from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from backend.api.features.library.db import set_preset_webhook, update_preset @@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform -from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.credentials_store import provider_matches +from backend.integrations.creds_manager import ( + IntegrationCredentialsManager, + create_mcp_oauth_handler, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.integrations.webhooks import get_webhook_manager @@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel): scopes: list[str] | None username: str | None host: str | None = Field( - default=None, description="Host pattern for host-scoped credentials" + default=None, + description="Host pattern for host-scoped or MCP server URL for MCP credentials", ) + @model_validator(mode="before") + @classmethod + def _normalize_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.""" + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + + @staticmethod + def get_host(cred: Credentials) -> str | None: + """Extract host from credential: HostScoped host or MCP server URL.""" + if isinstance(cred, HostScopedCredentials): + return cred.host + if isinstance(cred, OAuth2Credentials) and cred.provider in ( + ProviderName.MCP, + ProviderName.MCP.value, + "ProviderName.MCP", + ): + return (cred.metadata or {}).get("mcp_server_url") + return None + @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") async def callback( @@ -179,9 +211,7 @@ async def callback( title=credentials.title, scopes=credentials.scopes, username=credentials.username, - host=( - credentials.host if isinstance(credentials, 
HostScopedCredentials) else None - ), + host=(CredentialsMetaResponse.get_host(credentials)), ) @@ -199,7 +229,7 @@ async def list_credentials( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -222,7 +252,7 @@ async def list_credentials_by_provider( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -322,7 +352,11 @@ async def delete_credentials( tokens_revoked = None if isinstance(creds, OAuth2Credentials): - handler = _get_provider_oauth_handler(request, provider) + if provider_matches(provider.value, ProviderName.MCP.value): + # MCP uses dynamic per-server OAuth — create handler from metadata + handler = create_mcp_oauth_handler(creds) + else: + handler = _get_provider_oauth_handler(request, provider) tokens_revoked = await handler.revoke_tokens(creds) return CredentialsDeletionResponse(revoked=tokens_revoked) diff --git a/autogpt_platform/backend/backend/api/features/mcp/__init__.py b/autogpt_platform/backend/backend/api/features/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/api/features/mcp/routes.py b/autogpt_platform/backend/backend/api/features/mcp/routes.py new file mode 100644 index 0000000000..f8d311f372 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/routes.py @@ -0,0 +1,404 @@ +""" +MCP (Model Context Protocol) API routes. + +Provides endpoints for MCP tool discovery and OAuth authentication so the +frontend can list available tools on an MCP server before placing a block. 
+""" + +import logging +from typing import Annotated, Any +from urllib.parse import urlparse + +import fastapi +from autogpt_libs.auth import get_user_id +from fastapi import Security +from pydantic import BaseModel, Field + +from backend.api.features.integrations.router import CredentialsMetaResponse +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials +from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName +from backend.util.request import HTTPClientError, Requests +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +settings = Settings() +router = fastapi.APIRouter(tags=["mcp"]) +creds_manager = IntegrationCredentialsManager() + + +# ====================== Tool Discovery ====================== # + + +class DiscoverToolsRequest(BaseModel): + """Request to discover tools on an MCP server.""" + + server_url: str = Field(description="URL of the MCP server") + auth_token: str | None = Field( + default=None, + description="Optional Bearer token for authenticated MCP servers", + ) + + +class MCPToolResponse(BaseModel): + """A single MCP tool returned by discovery.""" + + name: str + description: str + input_schema: dict[str, Any] + + +class DiscoverToolsResponse(BaseModel): + """Response containing the list of tools available on an MCP server.""" + + tools: list[MCPToolResponse] + server_name: str | None = None + protocol_version: str | None = None + + +@router.post( + "/discover-tools", + summary="Discover available tools on an MCP server", + response_model=DiscoverToolsResponse, +) +async def discover_tools( + request: DiscoverToolsRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> DiscoverToolsResponse: + """ + Connect to an MCP server and return its available tools. + + If the user has a stored MCP credential for this server URL, it will be + used automatically — no need to pass an explicit auth token. + """ + auth_token = request.auth_token + + # Auto-use stored MCP credential when no explicit token is provided. + if not auth_token: + mcp_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + # Find the freshest credential for this server URL + best_cred: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == request.server_url + ): + if best_cred is None or ( + (cred.access_token_expires_at or 0) + > (best_cred.access_token_expires_at or 0) + ): + best_cred = cred + if best_cred: + # Refresh the token if expired before using it + best_cred = await creds_manager.refresh_if_needed(user_id, best_cred) + logger.info( + f"Using MCP credential {best_cred.id} for {request.server_url}, " + f"expires_at={best_cred.access_token_expires_at}" + ) + auth_token = best_cred.access_token.get_secret_value() + + client = MCPClient(request.server_url, auth_token=auth_token) + + try: + init_result = await client.initialize() + tools = await client.list_tools() + except HTTPClientError as e: + if e.status_code in (401, 403): + raise fastapi.HTTPException( + status_code=401, + detail="This MCP server requires authentication. 
" + "Please provide a valid auth token.", + ) + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except MCPClientError as e: + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except Exception as e: + raise fastapi.HTTPException( + status_code=502, + detail=f"Failed to connect to MCP server: {e}", + ) + + return DiscoverToolsResponse( + tools=[ + MCPToolResponse( + name=t.name, + description=t.description, + input_schema=t.input_schema, + ) + for t in tools + ], + server_name=( + init_result.get("serverInfo", {}).get("name") + or urlparse(request.server_url).hostname + or "MCP" + ), + protocol_version=init_result.get("protocolVersion"), + ) + + +# ======================== OAuth Flow ======================== # + + +class MCPOAuthLoginRequest(BaseModel): + """Request to start an OAuth flow for an MCP server.""" + + server_url: str = Field(description="URL of the MCP server that requires OAuth") + + +class MCPOAuthLoginResponse(BaseModel): + """Response with the OAuth login URL for the user to authenticate.""" + + login_url: str + state_token: str + + +@router.post( + "/oauth/login", + summary="Initiate OAuth login for an MCP server", +) +async def mcp_oauth_login( + request: MCPOAuthLoginRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> MCPOAuthLoginResponse: + """ + Discover OAuth metadata from the MCP server and return a login URL. + + 1. Discovers the protected-resource metadata (RFC 9728) + 2. Fetches the authorization server metadata (RFC 8414) + 3. Performs Dynamic Client Registration (RFC 7591) if available + 4. Returns the authorization URL for the frontend to open in a popup + """ + client = MCPClient(request.server_url) + + # Step 1: Discover protected-resource metadata (RFC 9728) + protected_resource = await client.discover_auth() + + metadata: dict[str, Any] | None = None + + if protected_resource and protected_resource.get("authorization_servers"): + auth_server_url = protected_resource["authorization_servers"][0] + resource_url = protected_resource.get("resource", request.server_url) + + # Step 2a: Discover auth-server metadata (RFC 8414) + metadata = await client.discover_auth_server_metadata(auth_server_url) + else: + # Fallback: Some MCP servers (e.g. Linear) are their own auth server + # and serve OAuth metadata directly without protected-resource metadata. + # Don't assume a resource_url — omitting it lets the auth server choose + # the correct audience for the token (RFC 8707 resource is optional). + resource_url = None + metadata = await client.discover_auth_server_metadata(request.server_url) + + if ( + not metadata + or "authorization_endpoint" not in metadata + or "token_endpoint" not in metadata + ): + raise fastapi.HTTPException( + status_code=400, + detail="This MCP server does not advertise OAuth support. 
" + "You may need to provide an auth token manually.", + ) + + authorize_url = metadata["authorization_endpoint"] + token_url = metadata["token_endpoint"] + registration_endpoint = metadata.get("registration_endpoint") + revoke_url = metadata.get("revocation_endpoint") + + # Step 3: Dynamic Client Registration (RFC 7591) if available + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + client_id = "" + client_secret = "" + if registration_endpoint: + reg_result = await _register_mcp_client( + registration_endpoint, redirect_uri, request.server_url + ) + if reg_result: + client_id = reg_result.get("client_id", "") + client_secret = reg_result.get("client_secret", "") + + if not client_id: + client_id = "autogpt-platform" + + # Step 4: Store state token with OAuth metadata for the callback + scopes = (protected_resource or {}).get("scopes_supported") or metadata.get( + "scopes_supported", [] + ) + state_token, code_challenge = await creds_manager.store.store_state_token( + user_id, + ProviderName.MCP.value, + scopes, + state_metadata={ + "authorize_url": authorize_url, + "token_url": token_url, + "revoke_url": revoke_url, + "resource_url": resource_url, + "server_url": request.server_url, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Step 5: Build and return the login URL + handler = MCPOAuthHandler( + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + authorize_url=authorize_url, + token_url=token_url, + resource_url=resource_url, + ) + login_url = handler.get_login_url( + scopes, state_token, code_challenge=code_challenge + ) + + return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token) + + +class MCPOAuthCallbackRequest(BaseModel): + """Request to exchange an OAuth code for tokens.""" + + code: str = Field(description="Authorization code from OAuth callback") + state_token: str = Field(description="State token for CSRF verification") + + +class MCPOAuthCallbackResponse(BaseModel): + """Response after successfully storing OAuth credentials.""" + + credential_id: str + + +@router.post( + "/oauth/callback", + summary="Exchange OAuth code for MCP tokens", +) +async def mcp_oauth_callback( + request: MCPOAuthCallbackRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> CredentialsMetaResponse: + """ + Exchange the authorization code for tokens and store the credential. + + The frontend calls this after receiving the OAuth code from the popup. + On success, subsequent ``/discover-tools`` calls for the same server URL + will automatically use the stored credential. 
+ """ + valid_state = await creds_manager.store.verify_state_token( + user_id, request.state_token, ProviderName.MCP.value + ) + if not valid_state: + raise fastapi.HTTPException( + status_code=400, + detail="Invalid or expired state token.", + ) + + meta = valid_state.state_metadata + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + handler = MCPOAuthHandler( + client_id=meta["client_id"], + client_secret=meta.get("client_secret", ""), + redirect_uri=redirect_uri, + authorize_url=meta["authorize_url"], + token_url=meta["token_url"], + revoke_url=meta.get("revoke_url"), + resource_url=meta.get("resource_url"), + ) + + try: + credentials = await handler.exchange_code_for_tokens( + request.code, valid_state.scopes, valid_state.code_verifier + ) + except Exception as e: + raise fastapi.HTTPException( + status_code=400, + detail=f"OAuth token exchange failed: {e}", + ) + + # Enrich credential metadata for future lookup and token refresh + if credentials.metadata is None: + credentials.metadata = {} + credentials.metadata["mcp_server_url"] = meta["server_url"] + credentials.metadata["mcp_client_id"] = meta["client_id"] + credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "") + credentials.metadata["mcp_token_url"] = meta["token_url"] + credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "") + + hostname = urlparse(meta["server_url"]).hostname or meta["server_url"] + credentials.title = f"MCP: {hostname}" + + # Remove old MCP credentials for the same server to prevent stale token buildup. + try: + old_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + for old in old_creds: + if ( + isinstance(old, OAuth2Credentials) + and (old.metadata or {}).get("mcp_server_url") == meta["server_url"] + ): + await creds_manager.store.delete_creds_by_id(user_id, old.id) + logger.info( + f"Removed old MCP credential {old.id} for {meta['server_url']}" + ) + except Exception: + logger.debug("Could not clean up old MCP credentials", exc_info=True) + + await creds_manager.create(user_id, credentials) + + return CredentialsMetaResponse( + id=credentials.id, + provider=credentials.provider, + type=credentials.type, + title=credentials.title, + scopes=credentials.scopes, + username=credentials.username, + host=credentials.metadata.get("mcp_server_url"), + ) + + +# ======================== Helpers ======================== # + + +async def _register_mcp_client( + registration_endpoint: str, + redirect_uri: str, + server_url: str, +) -> dict[str, Any] | None: + """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server.""" + try: + response = await Requests(raise_for_status=True).post( + registration_endpoint, + json={ + "client_name": "AutoGPT Platform", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + }, + ) + data = response.json() + if isinstance(data, dict) and "client_id" in data: + return data + return None + except Exception as e: + logger.warning(f"Dynamic client registration failed for {server_url}: {e}") + return None diff --git a/autogpt_platform/backend/backend/api/features/mcp/test_routes.py b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py new file mode 100644 index 
0000000000..e86b9f4865 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py @@ -0,0 +1,436 @@ +"""Tests for MCP API routes. + +Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient +to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop. +""" + +from unittest.mock import AsyncMock, patch + +import fastapi +import httpx +import pytest +import pytest_asyncio +from autogpt_libs.auth import get_user_id + +from backend.api.features.mcp.routes import router +from backend.blocks.mcp.client import MCPClientError, MCPTool +from backend.util.request import HTTPClientError + +app = fastapi.FastAPI() +app.include_router(router) +app.dependency_overrides[get_user_id] = lambda: "test-user-id" + + +@pytest_asyncio.fixture(scope="module") +async def client(): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestDiscoverTools: + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_success(self, client): + mock_tools = [ + MCPTool( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + MCPTool( + name="add_numbers", + description="Add two numbers", + input_schema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + }, + ), + ] + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={ + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test-server"}, + } + ) + instance.list_tools = AsyncMock(return_value=mock_tools) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["tools"]) == 2 + assert data["tools"][0]["name"] == "get_weather" + assert data["tools"][1]["name"] == "add_numbers" + assert data["server_name"] == "test-server" + assert data["protocol_version"] == "2025-03-26" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_with_auth_token(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={ + "server_url": "https://mcp.example.com/mcp", + "auth_token": "my-secret-token", + }, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="my-secret-token", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auto_uses_stored_credential(self, client): + """When no explicit token is given, stored MCP credentials are used.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + stored_cred = OAuth2Credentials( + provider="mcp", + title="MCP: example.com", + access_token=SecretStr("stored-token-123"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + 
scopes=[], + metadata={"mcp_server_url": "https://mcp.example.com/mcp"}, + ) + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred]) + mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="stored-token-123", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_mcp_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=MCPClientError("Connection refused") + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://bad-server.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Connection refused" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_generic_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock(side_effect=Exception("Network timeout")) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://timeout.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Failed to connect" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auth_required(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_forbidden(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + 
@pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_missing_url(self, client): + response = await client.post("/discover-tools", json={}) + assert response.status_code == 422 + + +class TestOAuthLogin: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_success(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch( + "backend.api.features.mcp.routes._register_mcp_client" + ) as mock_register, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.sentry.io"], + "resource": "https://mcp.sentry.dev/mcp", + "scopes_supported": ["openid"], + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.sentry.io/authorize", + "token_endpoint": "https://auth.sentry.io/token", + "registration_endpoint": "https://auth.sentry.io/register", + } + ) + mock_register.return_value = { + "client_id": "registered-client-id", + "client_secret": "registered-secret", + } + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-token-123", "code-challenge-abc") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.sentry.dev/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "login_url" in data + assert data["state_token"] == "state-token-123" + assert "auth.sentry.io/authorize" in data["login_url"] + assert "registered-client-id" in data["login_url"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_no_oauth_support(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.discover_auth = AsyncMock(return_value=None) + instance.discover_auth_server_metadata = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/login", + json={"server_url": "https://simple-server.example.com/mcp"}, + ) + + assert response.status_code == 400 + assert "does not advertise OAuth" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_fallback_to_public_client(self, client): + """When DCR is unavailable, falls back to default public client ID.""" + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + # No registration_endpoint + } + ) + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-abc", "challenge-xyz") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "autogpt-platform" in 
data["login_url"] + + +class TestOAuthCallback: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_success(self, client): + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + mock_creds = OAuth2Credentials( + provider="mcp", + title=None, + access_token=SecretStr("access-token-xyz"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={ + "mcp_token_url": "https://auth.sentry.io/token", + "mcp_resource_url": "https://mcp.sentry.dev/mcp", + }, + ) + + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + + # Mock state verification + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.sentry.io/authorize", + "token_url": "https://auth.sentry.io/token", + "client_id": "test-client-id", + "client_secret": "test-secret", + "server_url": "https://mcp.sentry.dev/mcp", + } + mock_state.scopes = ["openid"] + mock_state.code_verifier = "verifier-123" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + mock_cm.create = AsyncMock() + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + return_value=mock_creds + ) + + # Mock old credential cleanup + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code-abc", "state_token": "state-token-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["provider"] == "mcp" + assert data["type"] == "oauth2" + mock_cm.create.assert_called_once() + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_invalid_state(self, client): + with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm: + mock_cm.store.verify_state_token = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code", "state_token": "bad-state"}, + ) + + assert response.status_code == 400 + assert "Invalid or expired" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_token_exchange_fails(self, client): + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "server_url": "https://mcp.example.com/mcp", + } + mock_state.scopes = [] + mock_state.code_verifier = "v" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + side_effect=RuntimeError("Token exchange failed") + ) + + response = await client.post( + "/oauth/callback", + json={"code": "bad-code", "state_token": "state"}, + ) + + assert response.status_code == 400 + assert "token exchange failed" in response.json()["detail"].lower() diff --git 
a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 0eef76193e..aed348755b 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes import backend.api.features.library.db import backend.api.features.library.model import backend.api.features.library.routes +import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes import backend.api.features.postmark.postmark @@ -343,6 +344,11 @@ app.include_router( tags=["workspace"], prefix="/api/workspace", ) +app.include_router( + mcp_routes.router, + tags=["v2", "mcp"], + prefix="/api/mcp", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 0ba4daec40..632c5e43b9 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -64,6 +64,7 @@ class BlockType(Enum): AI = "AI" AYRSHARE = "Ayrshare" HUMAN_IN_THE_LOOP = "Human In The Loop" + MCP_TOOL = "MCP Tool" class BlockCategory(Enum): diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index f129d2707b..5fdcfb6d82 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -126,6 +126,7 @@ class PrintToConsoleBlock(Block): output_schema=PrintToConsoleBlock.Output, test_input={"text": "Hello, World!"}, is_sensitive_action=True, + disabled=True, # Disabled per Nick Tindle's request (OPEN-3000) test_output=[ ("output", "Hello, World!"), ("status", "printed"), diff --git a/autogpt_platform/backend/backend/blocks/data_manipulation.py b/autogpt_platform/backend/backend/blocks/data_manipulation.py index a8f25ecb18..fe878acfa9 100644 --- a/autogpt_platform/backend/backend/blocks/data_manipulation.py +++ b/autogpt_platform/backend/backend/blocks/data_manipulation.py @@ -682,17 +682,219 @@ class ListIsEmptyBlock(Block): yield "is_empty", len(input_data.list) == 0 +# ============================================================================= +# List Concatenation Helpers +# ============================================================================= + + +def _validate_list_input(item: Any, index: int) -> str | None: + """Validate that an item is a list. Returns error message or None.""" + if item is None: + return None # None is acceptable, will be skipped + if not isinstance(item, list): + return ( + f"Invalid input at index {index}: expected a list, " + f"got {type(item).__name__}. " + f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." + ) + return None + + +def _validate_all_lists(lists: List[Any]) -> str | None: + """Validate that all items in a sequence are lists. 
Returns first error or None.""" + for idx, item in enumerate(lists): + error = _validate_list_input(item, idx) + if error is not None and item is not None: + return error + return None + + +def _concatenate_lists_simple(lists: List[List[Any]]) -> List[Any]: + """Concatenate a sequence of lists into a single list, skipping None values.""" + result: List[Any] = [] + for lst in lists: + if lst is None: + continue + result.extend(lst) + return result + + +def _flatten_nested_list(nested: List[Any], max_depth: int = -1) -> List[Any]: + """ + Recursively flatten a nested list structure. + + Args: + nested: The list to flatten. + max_depth: Maximum recursion depth. -1 means unlimited. + + Returns: + A flat list with all nested elements extracted. + """ + result: List[Any] = [] + _flatten_recursive(nested, result, current_depth=0, max_depth=max_depth) + return result + + +_MAX_FLATTEN_DEPTH = 1000 + + +def _flatten_recursive( + items: List[Any], + result: List[Any], + current_depth: int, + max_depth: int, +) -> None: + """Internal recursive helper for flattening nested lists.""" + if current_depth > _MAX_FLATTEN_DEPTH: + raise RecursionError( + f"Flattening exceeded maximum depth of {_MAX_FLATTEN_DEPTH} levels. " + "Input may be too deeply nested." + ) + for item in items: + if isinstance(item, list) and (max_depth == -1 or current_depth < max_depth): + _flatten_recursive(item, result, current_depth + 1, max_depth) + else: + result.append(item) + + +def _deduplicate_list(items: List[Any]) -> List[Any]: + """ + Remove duplicate elements from a list, preserving order of first occurrences. + + Args: + items: The list to deduplicate. + + Returns: + A list with duplicates removed, maintaining original order. + """ + seen: set = set() + result: List[Any] = [] + for item in items: + item_id = _make_hashable(item) + if item_id not in seen: + seen.add(item_id) + result.append(item) + return result + + +def _make_hashable(item: Any): + """ + Create a hashable representation of any item for deduplication. + Converts unhashable types (dicts, lists) into deterministic tuple structures. + """ + if isinstance(item, dict): + return tuple( + sorted( + ((_make_hashable(k), _make_hashable(v)) for k, v in item.items()), + key=lambda x: (str(type(x[0])), str(x[0])), + ) + ) + if isinstance(item, (list, tuple)): + return tuple(_make_hashable(i) for i in item) + if isinstance(item, set): + return frozenset(_make_hashable(i) for i in item) + return item + + +def _filter_none_values(items: List[Any]) -> List[Any]: + """Remove None values from a list.""" + return [item for item in items if item is not None] + + +def _compute_nesting_depth( + items: Any, current: int = 0, max_depth: int = _MAX_FLATTEN_DEPTH +) -> int: + """ + Compute the maximum nesting depth of a list structure using iteration to avoid RecursionError. + + Uses a stack-based approach to handle deeply nested structures without hitting Python's + recursion limit (~1000 levels). 
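+
+    Illustrative values: _compute_nesting_depth([1, [2, [3]]]) == 3,
+    _compute_nesting_depth([]) == 1, and _compute_nesting_depth(5) == 0.
+    If the observed depth exceeds max_depth, the traversal stops early
+    and returns that depth.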
+ """ + if not isinstance(items, list): + return current + + # Stack contains tuples of (item, depth) + stack = [(items, current)] + max_observed_depth = current + + while stack: + item, depth = stack.pop() + + if depth > max_depth: + return depth + + if not isinstance(item, list): + max_observed_depth = max(max_observed_depth, depth) + continue + + if len(item) == 0: + max_observed_depth = max(max_observed_depth, depth + 1) + continue + + # Add all children to stack with incremented depth + for child in item: + stack.append((child, depth + 1)) + + return max_observed_depth + + +def _interleave_lists(lists: List[List[Any]]) -> List[Any]: + """ + Interleave elements from multiple lists in round-robin fashion. + Example: [[1,2,3], [a,b], [x,y,z]] -> [1, a, x, 2, b, y, 3, z] + """ + if not lists: + return [] + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + result: List[Any] = [] + max_len = max(len(lst) for lst in filtered) + for i in range(max_len): + for lst in filtered: + if i < len(lst): + result.append(lst[i]) + return result + + +# ============================================================================= +# List Concatenation Blocks +# ============================================================================= + + class ConcatenateListsBlock(Block): + """ + Concatenates two or more lists into a single list. + + This block accepts a list of lists and combines all their elements + in order into one flat output list. It supports options for + deduplication and None-filtering to provide flexible list merging + capabilities for workflow pipelines. + """ + class Input(BlockSchemaInput): lists: List[List[Any]] = SchemaField( description="A list of lists to concatenate together. All lists will be combined in order into a single list.", placeholder="e.g., [[1, 2], [3, 4], [5, 6]]", ) + deduplicate: bool = SchemaField( + description="If True, remove duplicate elements from the concatenated result while preserving order.", + default=False, + advanced=True, + ) + remove_none: bool = SchemaField( + description="If True, remove None values from the concatenated result.", + default=False, + advanced=True, + ) class Output(BlockSchemaOutput): concatenated_list: List[Any] = SchemaField( description="The concatenated list containing all elements from all input lists in order." ) + length: int = SchemaField( + description="The total number of elements in the concatenated list." + ) error: str = SchemaField( description="Error message if concatenation failed due to invalid input types." ) @@ -700,7 +902,7 @@ class ConcatenateListsBlock(Block): def __init__(self): super().__init__( id="3cf9298b-5817-4141-9d80-7c2cc5199c8e", - description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order.", + description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order. 
Supports optional deduplication and None removal.", categories={BlockCategory.BASIC}, input_schema=ConcatenateListsBlock.Input, output_schema=ConcatenateListsBlock.Output, @@ -709,29 +911,497 @@ class ConcatenateListsBlock(Block): {"lists": [["a", "b"], ["c"], ["d", "e", "f"]]}, {"lists": [[1, 2], []]}, {"lists": []}, + {"lists": [[1, 2, 2, 3], [3, 4]], "deduplicate": True}, + {"lists": [[1, None, 2], [None, 3]], "remove_none": True}, ], test_output=[ ("concatenated_list", [1, 2, 3, 4, 5, 6]), + ("length", 6), ("concatenated_list", ["a", "b", "c", "d", "e", "f"]), + ("length", 6), ("concatenated_list", [1, 2]), + ("length", 2), ("concatenated_list", []), + ("length", 0), + ("concatenated_list", [1, 2, 3, 4]), + ("length", 4), + ("concatenated_list", [1, 2, 3]), + ("length", 3), ], ) + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _perform_concatenation(self, lists: List[List[Any]]) -> List[Any]: + return _concatenate_lists_simple(lists) + + def _apply_deduplication(self, items: List[Any]) -> List[Any]: + return _deduplicate_list(items) + + def _apply_none_removal(self, items: List[Any]) -> List[Any]: + return _filter_none_values(items) + + def _post_process( + self, items: List[Any], deduplicate: bool, remove_none: bool + ) -> List[Any]: + """Apply all post-processing steps to the concatenated result.""" + result = items + if remove_none: + result = self._apply_none_removal(result) + if deduplicate: + result = self._apply_deduplication(result) + return result + async def run(self, input_data: Input, **kwargs) -> BlockOutput: - concatenated = [] - for idx, lst in enumerate(input_data.lists): - if lst is None: - # Skip None values to avoid errors - continue - if not isinstance(lst, list): - # Type validation: each item must be a list - # Strings are iterable and would cause extend() to iterate character-by-character - # Non-iterable types would raise TypeError - yield "error", ( - f"Invalid input at index {idx}: expected a list, got {type(lst).__name__}. " - f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." - ) - return - concatenated.extend(lst) - yield "concatenated_list", concatenated + # Validate all inputs are lists + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + # Perform concatenation + concatenated = self._perform_concatenation(input_data.lists) + + # Apply post-processing + result = self._post_process( + concatenated, input_data.deduplicate, input_data.remove_none + ) + + yield "concatenated_list", result + yield "length", len(result) + + +class FlattenListBlock(Block): + """ + Flattens a nested list structure into a single flat list. + + This block takes a list that may contain nested lists at any depth + and produces a single-level list with all leaf elements. Useful + for normalizing data structures from multiple sources that may + have varying levels of nesting. + """ + + class Input(BlockSchemaInput): + nested_list: List[Any] = SchemaField( + description="A potentially nested list to flatten into a single-level list.", + placeholder="e.g., [[1, [2, 3]], [4, [5, [6]]]]", + ) + max_depth: int = SchemaField( + description="Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level.", + default=-1, + advanced=True, + ) + + class Output(BlockSchemaOutput): + flattened_list: List[Any] = SchemaField( + description="The flattened list with all nested elements extracted." 
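+            # Illustrative note: max_depth=1 unwraps exactly one level, e.g.
+            # [1, [2, [3]]] -> [1, 2, [3]]; unlimited flattening (-1) is
+            # guarded by _MAX_FLATTEN_DEPTH (1000) and raises RecursionError
+            # past that depth.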
+ ) + length: int = SchemaField( + description="The number of elements in the flattened list." + ) + original_depth: int = SchemaField( + description="The maximum nesting depth of the original input list." + ) + error: str = SchemaField(description="Error message if flattening failed.") + + def __init__(self): + super().__init__( + id="cc45bb0f-d035-4756-96a7-fe3e36254b4d", + description="Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth.", + categories={BlockCategory.BASIC}, + input_schema=FlattenListBlock.Input, + output_schema=FlattenListBlock.Output, + test_input=[ + {"nested_list": [[1, 2], [3, [4, 5]]]}, + {"nested_list": [1, [2, [3, [4]]]]}, + {"nested_list": [1, [2, [3, [4]]], 5], "max_depth": 1}, + {"nested_list": []}, + {"nested_list": [1, 2, 3]}, + ], + test_output=[ + ("flattened_list", [1, 2, 3, 4, 5]), + ("length", 5), + ("original_depth", 3), + ("flattened_list", [1, 2, 3, 4]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", [1, 2, [3, [4]], 5]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", []), + ("length", 0), + ("original_depth", 1), + ("flattened_list", [1, 2, 3]), + ("length", 3), + ("original_depth", 1), + ], + ) + + def _compute_depth(self, items: List[Any]) -> int: + """Compute the nesting depth of the input list.""" + return _compute_nesting_depth(items) + + def _flatten(self, items: List[Any], max_depth: int) -> List[Any]: + """Flatten the list to the specified depth.""" + return _flatten_nested_list(items, max_depth=max_depth) + + def _validate_max_depth(self, max_depth: int) -> str | None: + """Validate the max_depth parameter.""" + if max_depth < -1: + return f"max_depth must be -1 (unlimited) or a non-negative integer, got {max_depth}" + return None + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + # Validate max_depth + depth_error = self._validate_max_depth(input_data.max_depth) + if depth_error is not None: + yield "error", depth_error + return + + original_depth = self._compute_depth(input_data.nested_list) + flattened = self._flatten(input_data.nested_list, input_data.max_depth) + + yield "flattened_list", flattened + yield "length", len(flattened) + yield "original_depth", original_depth + + +class InterleaveListsBlock(Block): + """ + Interleaves elements from multiple lists in round-robin fashion. + + Given multiple input lists, this block takes one element from each + list in turn, producing an output where elements alternate between + sources. Lists of different lengths are handled gracefully - shorter + lists simply stop contributing once exhausted. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to interleave. Elements will be taken in round-robin order.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + + class Output(BlockSchemaOutput): + interleaved_list: List[Any] = SchemaField( + description="The interleaved list with elements alternating from each input list." + ) + length: int = SchemaField( + description="The total number of elements in the interleaved list." 
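+            # Illustrative note: exhausted lists are simply skipped, so
+            # [[1, 2, 3], ["a"]] interleaves to [1, "a", 2, 3] and the
+            # length equals the sum of all input list lengths.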
+ ) + error: str = SchemaField(description="Error message if interleaving failed.") + + def __init__(self): + super().__init__( + id="9f616084-1d9f-4f8e-bc00-5b9d2a75cd75", + description="Interleaves elements from multiple lists in round-robin fashion, alternating between sources.", + categories={BlockCategory.BASIC}, + input_schema=InterleaveListsBlock.Input, + output_schema=InterleaveListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"], ["x", "y", "z"]]}, + {"lists": [[1], [2], [3]]}, + {"lists": []}, + ], + test_output=[ + ("interleaved_list", [1, "a", 2, "b", 3, "c"]), + ("length", 6), + ("interleaved_list", [1, "a", "x", 2, "b", "y", 3, "z"]), + ("length", 8), + ("interleaved_list", [1, 2, 3]), + ("length", 3), + ("interleaved_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _interleave(self, lists: List[List[Any]]) -> List[Any]: + return _interleave_lists(lists) + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + result = self._interleave(input_data.lists) + yield "interleaved_list", result + yield "length", len(result) + + +class ZipListsBlock(Block): + """ + Zips multiple lists together into a list of grouped tuples/lists. + + Takes two or more input lists and combines corresponding elements + into sub-lists. For example, zipping [1,2,3] and ['a','b','c'] + produces [[1,'a'], [2,'b'], [3,'c']]. Supports both truncating + to shortest list and padding to longest list with a fill value. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to zip together. Corresponding elements will be grouped.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + pad_to_longest: bool = SchemaField( + description="If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest.", + default=False, + advanced=True, + ) + fill_value: Any = SchemaField( + description="Value to use for padding when pad_to_longest is True.", + default=None, + advanced=True, + ) + + class Output(BlockSchemaOutput): + zipped_list: List[List[Any]] = SchemaField( + description="The zipped list of grouped elements." + ) + length: int = SchemaField( + description="The number of groups in the zipped result." + ) + error: str = SchemaField(description="Error message if zipping failed.") + + def __init__(self): + super().__init__( + id="0d0e684f-5cb9-4c4b-b8d1-47a0860e0c07", + description="Zips multiple lists together into a list of grouped elements. 
Supports padding to longest or truncating to shortest.", + categories={BlockCategory.BASIC}, + input_schema=ZipListsBlock.Input, + output_schema=ZipListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"]]}, + { + "lists": [[1, 2], ["a", "b", "c"]], + "pad_to_longest": True, + "fill_value": 0, + }, + {"lists": []}, + ], + test_output=[ + ("zipped_list", [[1, "a"], [2, "b"], [3, "c"]]), + ("length", 3), + ("zipped_list", [[1, "a"], [2, "b"]]), + ("length", 2), + ("zipped_list", [[1, "a"], [2, "b"], [0, "c"]]), + ("length", 3), + ("zipped_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _zip_truncate(self, lists: List[List[Any]]) -> List[List[Any]]: + """Zip lists, truncating to shortest.""" + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + return [list(group) for group in zip(*filtered)] + + def _zip_pad(self, lists: List[List[Any]], fill_value: Any) -> List[List[Any]]: + """Zip lists, padding shorter ones with fill_value.""" + if not lists: + return [] + lists = [lst for lst in lists if lst is not None] + if not lists: + return [] + max_len = max(len(lst) for lst in lists) + result: List[List[Any]] = [] + for i in range(max_len): + group: List[Any] = [] + for lst in lists: + if i < len(lst): + group.append(lst[i]) + else: + group.append(fill_value) + result.append(group) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + if not input_data.lists: + yield "zipped_list", [] + yield "length", 0 + return + + if input_data.pad_to_longest: + result = self._zip_pad(input_data.lists, input_data.fill_value) + else: + result = self._zip_truncate(input_data.lists) + + yield "zipped_list", result + yield "length", len(result) + + +class ListDifferenceBlock(Block): + """ + Computes the difference between two lists (elements in the first + list that are not in the second list). + + This is useful for finding items that exist in one dataset but + not in another, such as finding new items, missing items, or + items that need to be processed. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The primary list to check elements from.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The list to subtract. Elements found here will be removed from list_a.", + placeholder="e.g., [3, 4, 5, 6]", + ) + symmetric: bool = SchemaField( + description="If True, compute symmetric difference (elements in either list but not both).", + default=False, + advanced=True, + ) + + class Output(BlockSchemaOutput): + difference: List[Any] = SchemaField( + description="Elements from list_a not found in list_b (or symmetric difference if enabled)." + ) + length: int = SchemaField( + description="The number of elements in the difference result." + ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="05309873-9d61-447e-96b5-b804e2511829", + description="Computes the difference between two lists. 
Returns elements in the first list not found in the second, or symmetric difference.", + categories={BlockCategory.BASIC}, + input_schema=ListDifferenceBlock.Input, + output_schema=ListDifferenceBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + { + "list_a": [1, 2, 3, 4, 5], + "list_b": [3, 4, 5, 6, 7], + "symmetric": True, + }, + {"list_a": ["a", "b", "c"], "list_b": ["b"]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("difference", [1, 2]), + ("length", 2), + ("difference", [1, 2, 6, 7]), + ("length", 4), + ("difference", ["a", "c"]), + ("length", 2), + ("difference", []), + ("length", 0), + ], + ) + + def _compute_difference(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements in list_a not in list_b.""" + b_hashes = {_make_hashable(item) for item in list_b} + return [item for item in list_a if _make_hashable(item) not in b_hashes] + + def _compute_symmetric_difference( + self, list_a: List[Any], list_b: List[Any] + ) -> List[Any]: + """Compute elements in either list but not both.""" + a_hashes = {_make_hashable(item) for item in list_a} + b_hashes = {_make_hashable(item) for item in list_b} + only_in_a = [item for item in list_a if _make_hashable(item) not in b_hashes] + only_in_b = [item for item in list_b if _make_hashable(item) not in a_hashes] + return only_in_a + only_in_b + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + if input_data.symmetric: + result = self._compute_symmetric_difference( + input_data.list_a, input_data.list_b + ) + else: + result = self._compute_difference(input_data.list_a, input_data.list_b) + + yield "difference", result + yield "length", len(result) + + +class ListIntersectionBlock(Block): + """ + Computes the intersection of two lists (elements present in both lists). + + This is useful for finding common items between two datasets, + such as shared tags, mutual connections, or overlapping categories. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The first list to intersect.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The second list to intersect.", + placeholder="e.g., [3, 4, 5, 6, 7]", + ) + + class Output(BlockSchemaOutput): + intersection: List[Any] = SchemaField( + description="Elements present in both list_a and list_b." + ) + length: int = SchemaField( + description="The number of elements in the intersection." 
+ ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="b6eb08b6-dbe3-411b-b9b4-2508cb311a1f", + description="Computes the intersection of two lists, returning only elements present in both.", + categories={BlockCategory.BASIC}, + input_schema=ListIntersectionBlock.Input, + output_schema=ListIntersectionBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + {"list_a": ["a", "b", "c"], "list_b": ["c", "d", "e"]}, + {"list_a": [1, 2], "list_b": [3, 4]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("intersection", [3, 4, 5]), + ("length", 3), + ("intersection", ["c"]), + ("length", 1), + ("intersection", []), + ("length", 0), + ("intersection", []), + ("length", 0), + ], + ) + + def _compute_intersection(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements present in both lists, preserving order from list_a.""" + b_hashes = {_make_hashable(item) for item in list_b} + seen: set = set() + result: List[Any] = [] + for item in list_a: + h = _make_hashable(item) + if h in b_hashes and h not in seen: + result.append(item) + seen.add(h) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + result = self._compute_intersection(input_data.list_a, input_data.list_b) + yield "intersection", result + yield "length", len(result) diff --git a/autogpt_platform/backend/backend/blocks/jina/search.py b/autogpt_platform/backend/backend/blocks/jina/search.py index 22a883fa03..5e58ddcab4 100644 --- a/autogpt_platform/backend/backend/blocks/jina/search.py +++ b/autogpt_platform/backend/backend/blocks/jina/search.py @@ -17,6 +17,7 @@ from backend.blocks.jina._auth import ( from backend.blocks.search import GetRequest from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError +from backend.util.request import HTTPClientError, HTTPServerError, validate_url class SearchTheWebBlock(Block, GetRequest): @@ -110,7 +111,12 @@ class ExtractWebsiteContentBlock(Block, GetRequest): self, input_data: Input, *, credentials: JinaCredentials, **kwargs ) -> BlockOutput: if input_data.raw_content: - url = input_data.url + try: + parsed_url, _, _ = await validate_url(input_data.url, []) + url = parsed_url.geturl() + except ValueError as e: + yield "error", f"Invalid URL: {e}" + return headers = {} else: url = f"https://r.jina.ai/{input_data.url}" @@ -119,5 +125,20 @@ class ExtractWebsiteContentBlock(Block, GetRequest): "Authorization": f"Bearer {credentials.api_key.get_secret_value()}", } - content = await self.get_request(url, json=False, headers=headers) + try: + content = await self.get_request(url, json=False, headers=headers) + except HTTPClientError as e: + yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}" + return + except HTTPServerError as e: + yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}" + return + except Exception as e: + yield "error", f"Failed to fetch {input_data.url}: {e}" + return + + if not content: + yield "error", f"No content returned for {input_data.url}" + return + yield "content", content diff --git a/autogpt_platform/backend/backend/blocks/mcp/__init__.py b/autogpt_platform/backend/backend/blocks/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py new file mode 100644 index 
0000000000..9e3056d928 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -0,0 +1,300 @@ +""" +MCP (Model Context Protocol) Tool Block. + +A single dynamic block that can connect to any MCP server, discover available tools, +and execute them. Works like AgentExecutorBlock — the user selects a tool from a +dropdown and the input/output schema adapts dynamically. +""" + +import json +import logging +from typing import Any, Literal + +from pydantic import SecretStr + +from backend.blocks._base import ( + Block, + BlockCategory, + BlockSchemaInput, + BlockSchemaOutput, + BlockType, +) +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.data.block import BlockInput, BlockOutput +from backend.data.model import ( + CredentialsField, + CredentialsMetaInput, + OAuth2Credentials, + SchemaField, +) +from backend.integrations.providers import ProviderName +from backend.util.json import validate_with_jsonschema + +logger = logging.getLogger(__name__) + +TEST_CREDENTIALS = OAuth2Credentials( + id="test-mcp-cred", + provider="mcp", + access_token=SecretStr("mock-mcp-token"), + refresh_token=SecretStr("mock-refresh"), + scopes=[], + title="Mock MCP credential", +) +TEST_CREDENTIALS_INPUT = { + "provider": TEST_CREDENTIALS.provider, + "id": TEST_CREDENTIALS.id, + "type": TEST_CREDENTIALS.type, + "title": TEST_CREDENTIALS.title, +} + + +MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]] + + +class MCPToolBlock(Block): + """ + A block that connects to an MCP server, lets the user pick a tool, + and executes it with dynamic input/output schema. + + The flow: + 1. User provides an MCP server URL (and optional credentials) + 2. Frontend calls the backend to get tool list from that URL + 3. User selects a tool from a dropdown (available_tools) + 4. The block's input schema updates to reflect the selected tool's parameters + 5. On execution, the block calls the MCP server to run the tool + """ + + class Input(BlockSchemaInput): + server_url: str = SchemaField( + description="URL of the MCP server (Streamable HTTP endpoint)", + placeholder="https://mcp.example.com/mcp", + ) + credentials: MCPCredentials = CredentialsField( + discriminator="server_url", + description="MCP server OAuth credentials", + default={}, + ) + selected_tool: str = SchemaField( + description="The MCP tool to execute", + placeholder="Select a tool", + default="", + ) + tool_input_schema: dict[str, Any] = SchemaField( + description="JSON Schema for the selected tool's input parameters. " + "Populated automatically when a tool is selected.", + default={}, + hidden=True, + ) + + tool_arguments: dict[str, Any] = SchemaField( + description="Arguments to pass to the selected MCP tool. 
" + "The fields here are defined by the tool's input schema.", + default={}, + ) + + @classmethod + def get_input_schema(cls, data: BlockInput) -> dict[str, Any]: + """Return the tool's input schema so the builder UI renders dynamic fields.""" + return data.get("tool_input_schema", {}) + + @classmethod + def get_input_defaults(cls, data: BlockInput) -> BlockInput: + """Return the current tool_arguments as defaults for the dynamic fields.""" + return data.get("tool_arguments", {}) + + @classmethod + def get_missing_input(cls, data: BlockInput) -> set[str]: + """Check which required tool arguments are missing.""" + required_fields = cls.get_input_schema(data).get("required", []) + tool_arguments = data.get("tool_arguments", {}) + return set(required_fields) - set(tool_arguments) + + @classmethod + def get_mismatch_error(cls, data: BlockInput) -> str | None: + """Validate tool_arguments against the tool's input schema.""" + tool_schema = cls.get_input_schema(data) + if not tool_schema: + return None + tool_arguments = data.get("tool_arguments", {}) + return validate_with_jsonschema(tool_schema, tool_arguments) + + class Output(BlockSchemaOutput): + result: Any = SchemaField(description="The result returned by the MCP tool") + error: str = SchemaField(description="Error message if the tool call failed") + + def __init__(self): + super().__init__( + id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", + description="Connect to any MCP server and execute its tools. " + "Provide a server URL, select a tool, and pass arguments dynamically.", + categories={BlockCategory.DEVELOPER_TOOLS}, + input_schema=MCPToolBlock.Input, + output_schema=MCPToolBlock.Output, + block_type=BlockType.MCP_TOOL, + test_credentials=TEST_CREDENTIALS, + test_input={ + "server_url": "https://mcp.example.com/mcp", + "credentials": TEST_CREDENTIALS_INPUT, + "selected_tool": "get_weather", + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + }, + test_output=[ + ( + "result", + {"weather": "sunny", "temperature": 20}, + ), + ], + test_mock={ + "_call_mcp_tool": lambda *a, **kw: { + "weather": "sunny", + "temperature": 20, + }, + }, + ) + + async def _call_mcp_tool( + self, + server_url: str, + tool_name: str, + arguments: dict[str, Any], + auth_token: str | None = None, + ) -> Any: + """Call a tool on the MCP server. 
Extracted for easy mocking in tests.""" + client = MCPClient(server_url, auth_token=auth_token) + await client.initialize() + result = await client.call_tool(tool_name, arguments) + + if result.is_error: + error_text = "" + for item in result.content: + if item.get("type") == "text": + error_text += item.get("text", "") + raise MCPClientError( + f"MCP tool '{tool_name}' returned an error: " + f"{error_text or 'Unknown error'}" + ) + + # Extract text content from the result + output_parts = [] + for item in result.content: + if item.get("type") == "text": + text = item.get("text", "") + # Try to parse as JSON for structured output + try: + output_parts.append(json.loads(text)) + except (json.JSONDecodeError, ValueError): + output_parts.append(text) + elif item.get("type") == "image": + output_parts.append( + { + "type": "image", + "data": item.get("data"), + "mimeType": item.get("mimeType"), + } + ) + elif item.get("type") == "resource": + output_parts.append(item.get("resource", {})) + + # If single result, unwrap + if len(output_parts) == 1: + return output_parts[0] + return output_parts if output_parts else None + + @staticmethod + async def _auto_lookup_credential( + user_id: str, server_url: str + ) -> "OAuth2Credentials | None": + """Auto-lookup stored MCP credential for a server URL. + + This is a fallback for nodes that don't have ``credentials`` explicitly + set (e.g. nodes created before the credential field was wired up). + """ + from backend.integrations.creds_manager import IntegrationCredentialsManager + from backend.integrations.providers import ProviderName + + try: + mgr = IntegrationCredentialsManager() + mcp_creds = await mgr.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + best: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == server_url + ): + if best is None or ( + (cred.access_token_expires_at or 0) + > (best.access_token_expires_at or 0) + ): + best = cred + if best: + best = await mgr.refresh_if_needed(user_id, best) + logger.info( + "Auto-resolved MCP credential %s for %s", best.id, server_url + ) + return best + except Exception: + logger.warning("Auto-lookup MCP credential failed", exc_info=True) + return None + + async def run( + self, + input_data: Input, + *, + user_id: str, + credentials: OAuth2Credentials | None = None, + **kwargs, + ) -> BlockOutput: + if not input_data.server_url: + yield "error", "MCP server URL is required" + return + + if not input_data.selected_tool: + yield "error", "No tool selected. Please select a tool from the dropdown." + return + + # Validate required tool arguments before calling the server. + # The executor-level validation is bypassed for MCP blocks because + # get_input_defaults() flattens tool_arguments, stripping tool_input_schema + # from the validation context. + required = set(input_data.tool_input_schema.get("required", [])) + if required: + missing = required - set(input_data.tool_arguments.keys()) + if missing: + yield "error", ( + f"Missing required argument(s): {', '.join(sorted(missing))}. " + f"Please fill in all required fields marked with * in the block form." + ) + return + + # If no credentials were injected by the executor (e.g. legacy nodes + # that don't have the credentials field set), try to auto-lookup + # the stored MCP credential for this server URL. 
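+        # Resolution order: credentials injected by the executor take
+        # precedence; otherwise the stored MCP OAuth credential with the
+        # latest expiry whose metadata matches this server_url is used; if
+        # neither exists, the request is sent unauthenticated (public MCP
+        # servers accept this).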
+ if credentials is None: + credentials = await self._auto_lookup_credential( + user_id, input_data.server_url + ) + + auth_token = ( + credentials.access_token.get_secret_value() if credentials else None + ) + + try: + result = await self._call_mcp_tool( + server_url=input_data.server_url, + tool_name=input_data.selected_tool, + arguments=input_data.tool_arguments, + auth_token=auth_token, + ) + yield "result", result + except MCPClientError as e: + yield "error", str(e) + except Exception as e: + logger.exception(f"MCP tool call failed: {e}") + yield "error", f"MCP tool call failed: {str(e)}" diff --git a/autogpt_platform/backend/backend/blocks/mcp/client.py b/autogpt_platform/backend/backend/blocks/mcp/client.py new file mode 100644 index 0000000000..050349dbcc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/client.py @@ -0,0 +1,323 @@ +""" +MCP (Model Context Protocol) HTTP client. + +Implements the MCP Streamable HTTP transport for listing tools and calling tools +on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST. + +Handles both JSON and SSE (text/event-stream) response formats per the MCP spec. + +Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +""" + +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +@dataclass +class MCPTool: + """Represents an MCP tool discovered from a server.""" + + name: str + description: str + input_schema: dict[str, Any] + + +@dataclass +class MCPCallResult: + """Result from calling an MCP tool.""" + + content: list[dict[str, Any]] = field(default_factory=list) + is_error: bool = False + + +class MCPClientError(Exception): + """Raised when an MCP protocol error occurs.""" + + pass + + +class MCPClient: + """ + Async HTTP client for the MCP Streamable HTTP transport. + + Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST. + Supports optional Bearer token authentication. + """ + + def __init__( + self, + server_url: str, + auth_token: str | None = None, + ): + self.server_url = server_url.rstrip("/") + self.auth_token = auth_token + self._request_id = 0 + self._session_id: str | None = None + + def _next_id(self) -> int: + self._request_id += 1 + return self._request_id + + def _build_headers(self) -> dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + return headers + + def _build_jsonrpc_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + req: dict[str, Any] = { + "jsonrpc": "2.0", + "method": method, + "id": self._next_id(), + } + if params is not None: + req["params"] = params + return req + + @staticmethod + def _parse_sse_response(text: str) -> dict[str, Any]: + """Parse an SSE (text/event-stream) response body into JSON-RPC data. + + MCP servers may return responses as SSE with format: + event: message + data: {"jsonrpc":"2.0","result":{...},"id":1} + + We extract the last `data:` line that contains a JSON-RPC response + (i.e. has an "id" field), which is the reply to our request. 
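+
+        Illustrative stream (hypothetical payloads) showing why notifications
+        without an "id" are skipped:
+
+            data: {"jsonrpc":"2.0","method":"notifications/progress"}
+            data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}
+
+        The first line is ignored; the second is parsed and returned.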
+ """ + last_data: dict[str, Any] | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + payload = stripped[len("data:") :].strip() + if not payload: + continue + try: + parsed = json.loads(payload) + # Only keep JSON-RPC responses (have "id"), skip notifications + if isinstance(parsed, dict) and "id" in parsed: + last_data = parsed + except (json.JSONDecodeError, ValueError): + continue + if last_data is None: + raise MCPClientError("No JSON-RPC response found in SSE stream") + return last_data + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """Send a JSON-RPC request to the MCP server and return the result. + + Handles both ``application/json`` and ``text/event-stream`` responses + as required by the MCP Streamable HTTP transport specification. + """ + payload = self._build_jsonrpc_request(method, params) + headers = self._build_headers() + + requests = Requests( + raise_for_status=True, + extra_headers=headers, + ) + response = await requests.post(self.server_url, json=payload) + + # Capture session ID from response (MCP Streamable HTTP transport) + session_id = response.headers.get("Mcp-Session-Id") + if session_id: + self._session_id = session_id + + content_type = response.headers.get("content-type", "") + if "text/event-stream" in content_type: + body = self._parse_sse_response(response.text()) + else: + try: + body = response.json() + except Exception as e: + raise MCPClientError( + f"MCP server returned non-JSON response: {e}" + ) from e + + if not isinstance(body, dict): + raise MCPClientError( + f"MCP server returned unexpected JSON type: {type(body).__name__}" + ) + + # Handle JSON-RPC error + if "error" in body: + error = body["error"] + if isinstance(error, dict): + raise MCPClientError( + f"MCP server error [{error.get('code', '?')}]: " + f"{error.get('message', 'Unknown error')}" + ) + raise MCPClientError(f"MCP server error: {error}") + + return body.get("result") + + async def _send_notification(self, method: str) -> None: + """Send a JSON-RPC notification (no id, no response expected).""" + headers = self._build_headers() + notification = {"jsonrpc": "2.0", "method": method} + requests = Requests( + raise_for_status=False, + extra_headers=headers, + ) + await requests.post(self.server_url, json=notification) + + async def discover_auth(self) -> dict[str, Any] | None: + """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec). + + Returns ``None`` if the server doesn't require auth, otherwise returns + a dict with: + - ``authorization_servers``: list of authorization server URLs + - ``resource``: the resource indicator URL (usually the MCP endpoint) + - ``scopes_supported``: optional list of supported scopes + + The caller can then fetch the authorization server metadata to get + ``authorization_endpoint``, ``token_endpoint``, etc. 
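+
+        Typical discovery flow (sketch, chaining this module's own methods)::
+
+            meta = await client.discover_auth()
+            if meta:
+                auth_server_meta = await client.discover_auth_server_metadata(
+                    meta["authorization_servers"][0]
+                )
+                # auth_server_meta then holds authorization_endpoint,
+                # token_endpoint, etc.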
+ """ + from urllib.parse import urlparse + + parsed = urlparse(self.server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + + # Build candidates for protected-resource metadata (per RFC 9728) + path = parsed.path.rstrip("/") + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-protected-resource{path}") + candidates.append(f"{base}/.well-known/oauth-protected-resource") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_servers" in data: + return data + except Exception: + continue + + return None + + async def discover_auth_server_metadata( + self, auth_server_url: str + ) -> dict[str, Any] | None: + """Fetch the OAuth Authorization Server Metadata (RFC 8414). + + Given an authorization server URL, returns a dict with: + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``registration_endpoint`` (for dynamic client registration) + - ``scopes_supported`` + - ``code_challenge_methods_supported`` + - etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(auth_server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + + # Try standard metadata endpoints (RFC 8414 and OpenID Connect) + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-authorization-server{path}") + candidates.append(f"{base}/.well-known/oauth-authorization-server") + candidates.append(f"{base}/.well-known/openid-configuration") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_endpoint" in data: + return data + except Exception: + continue + + return None + + async def initialize(self) -> dict[str, Any]: + """ + Send the MCP initialize request. + + This is required by the MCP protocol before any other requests. + Returns the server's capabilities. + """ + result = await self._send_request( + "initialize", + { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"}, + }, + ) + # Send initialized notification (no response expected) + await self._send_notification("notifications/initialized") + + return result or {} + + async def list_tools(self) -> list[MCPTool]: + """ + Discover available tools from the MCP server. + + Returns a list of MCPTool objects with name, description, and input schema. + """ + result = await self._send_request("tools/list") + if not result or "tools" not in result: + return [] + + tools = [] + for tool_data in result["tools"]: + tools.append( + MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + ) + ) + return tools + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] + ) -> MCPCallResult: + """ + Call a tool on the MCP server. + + Args: + tool_name: The name of the tool to call. + arguments: The arguments to pass to the tool. + + Returns: + MCPCallResult with the tool's response content. 
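+
+        Usage sketch (``get_weather`` is a hypothetical tool name)::
+
+            client = MCPClient("https://mcp.example.com/mcp")
+            await client.initialize()  # the handshake is required first
+            result = await client.call_tool("get_weather", {"city": "London"})
+            if not result.is_error:
+                print(result.content)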
+ """ + result = await self._send_request( + "tools/call", + {"name": tool_name, "arguments": arguments}, + ) + if not result: + return MCPCallResult(is_error=True) + + return MCPCallResult( + content=result.get("content", []), + is_error=result.get("isError", False), + ) diff --git a/autogpt_platform/backend/backend/blocks/mcp/oauth.py b/autogpt_platform/backend/backend/blocks/mcp/oauth.py new file mode 100644 index 0000000000..2228336cd3 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/oauth.py @@ -0,0 +1,204 @@ +""" +MCP OAuth handler for MCP servers that use OAuth 2.1 authorization. + +Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed, +MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata. +This handler accepts those endpoints at construction time. +""" + +import logging +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +class MCPOAuthHandler(BaseOAuthHandler): + """ + OAuth handler for MCP servers with dynamically-discovered endpoints. + + Construction requires the authorization and token endpoint URLs, + which are obtained via MCP OAuth metadata discovery + (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``). + """ + + PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP + DEFAULT_SCOPES: ClassVar[list[str]] = [] + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + *, + authorize_url: str, + token_url: str, + revoke_url: str | None = None, + resource_url: str | None = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.authorize_url = authorize_url + self.token_url = token_url + self.revoke_url = revoke_url + self.resource_url = resource_url + + def get_login_url( + self, + scopes: list[str], + state: str, + code_challenge: Optional[str], + ) -> str: + scopes = self.handle_default_scopes(scopes) + + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": state, + } + if scopes: + params["scope"] = " ".join(scopes) + # PKCE (S256) — included when the caller provides a code_challenge + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + # MCP spec requires resource indicator (RFC 8707) + if self.resource_url: + params["resource"] = self.resource_url + + return f"{self.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, + code: str, + scopes: list[str], + code_verifier: Optional[str], + ) -> OAuth2Credentials: + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if code_verifier: + data["code_verifier"] = code_verifier + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token exchange failed: 
{tokens.get('error_description', tokens['error'])}"
+            )
+
+        if "access_token" not in tokens:
+            raise RuntimeError("OAuth token response missing 'access_token' field")
+
+        now = int(time.time())
+        expires_in = tokens.get("expires_in")
+
+        return OAuth2Credentials(
+            provider=self.PROVIDER_NAME,
+            title=None,
+            access_token=SecretStr(tokens["access_token"]),
+            refresh_token=(
+                SecretStr(tokens["refresh_token"])
+                if tokens.get("refresh_token")
+                else None
+            ),
+            access_token_expires_at=now + expires_in if expires_in else None,
+            refresh_token_expires_at=None,
+            scopes=scopes,
+            metadata={
+                "mcp_token_url": self.token_url,
+                "mcp_resource_url": self.resource_url,
+            },
+        )
+
+    async def _refresh_tokens(
+        self, credentials: OAuth2Credentials
+    ) -> OAuth2Credentials:
+        if not credentials.refresh_token:
+            raise ValueError("No refresh token available for MCP OAuth credentials")
+
+        data: dict[str, str] = {
+            "grant_type": "refresh_token",
+            "refresh_token": credentials.refresh_token.get_secret_value(),
+            "client_id": self.client_id,
+        }
+        if self.client_secret:
+            data["client_secret"] = self.client_secret
+        if self.resource_url:
+            data["resource"] = self.resource_url
+
+        response = await Requests(raise_for_status=True).post(
+            self.token_url,
+            data=data,
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+        )
+        tokens = response.json()
+
+        if "error" in tokens:
+            raise RuntimeError(
+                f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
+            )
+
+        if "access_token" not in tokens:
+            raise RuntimeError("OAuth refresh response missing 'access_token' field")
+
+        now = int(time.time())
+        expires_in = tokens.get("expires_in")
+
+        return OAuth2Credentials(
+            id=credentials.id,
+            provider=self.PROVIDER_NAME,
+            title=credentials.title,
+            access_token=SecretStr(tokens["access_token"]),
+            refresh_token=(
+                SecretStr(tokens["refresh_token"])
+                if tokens.get("refresh_token")
+                else credentials.refresh_token
+            ),
+            access_token_expires_at=now + expires_in if expires_in else None,
+            refresh_token_expires_at=credentials.refresh_token_expires_at,
+            scopes=credentials.scopes,
+            metadata=credentials.metadata,
+        )
+
+    async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
+        if not self.revoke_url:
+            return False
+
+        try:
+            data = {
+                "token": credentials.access_token.get_secret_value(),
+                "token_type_hint": "access_token",
+                "client_id": self.client_id,
+            }
+            await Requests().post(
+                self.revoke_url,
+                data=data,
+                headers={"Content-Type": "application/x-www-form-urlencoded"},
+            )
+            return True
+        except Exception:
+            logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
+            return False
diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
new file mode 100644
index 0000000000..7818fac9ce
--- /dev/null
+++ b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
@@ -0,0 +1,109 @@
+"""
+End-to-end tests against a real public MCP server.
+
+These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp),
+which is publicly accessible without authentication and returns SSE responses.
+
+These tests are skipped unless the ``RUN_E2E`` env var is set, so they can run
+independently of the rest of the test suite (they require network access).
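+
+Typical invocation (sketch; path relative to the backend package root)::
+
+    RUN_E2E=1 pytest backend/blocks/mcp/test_e2e.py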
+""" + +import json +import os + +import pytest + +from backend.blocks.mcp.client import MCPClient + +# Public MCP server that requires no authentication +OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp" + +# Skip all tests in this module unless RUN_E2E env var is set +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests" +) + + +class TestRealMCPServer: + """Tests against the live OpenAI docs MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + """Verify we can complete the MCP handshake with a real server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "openai-docs-mcp" + assert "tools" in result.get("capabilities", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + """Verify we can discover tools from a real MCP server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) >= 3 # server has at least 5 tools as of writing + + tool_names = {t.name for t in tools} + # These tools are documented and should be stable + assert "search_openai_docs" in tool_names + assert "list_openai_docs" in tool_names + assert "fetch_openai_doc" in tool_names + + # Verify schema structure + search_tool = next(t for t in tools if t.name == "search_openai_docs") + assert "query" in search_tool.input_schema.get("properties", {}) + assert "query" in search_tool.input_schema.get("required", []) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_list_api_endpoints(self): + """Call the list_api_endpoints tool and verify we get real data.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool("list_api_endpoints", {}) + + assert not result.is_error + assert len(result.content) >= 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert "paths" in data or "urls" in data + # The OpenAI API should have many endpoints + total = data.get("total", len(data.get("paths", []))) + assert total > 50 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_search(self): + """Search for docs and verify we get results.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool( + "search_openai_docs", {"query": "chat completions", "limit": 3} + ) + + assert not result.is_error + assert len(result.content) >= 1 + + @pytest.mark.asyncio(loop_scope="session") + async def test_sse_response_handling(self): + """Verify the client correctly handles SSE responses from a real server. + + This is the key test — our local test server returns JSON, + but real MCP servers typically return SSE. This proves the + SSE parsing works end-to-end. 
+ """ + client = MCPClient(OPENAI_DOCS_MCP_URL) + # initialize() internally calls _send_request which must parse SSE + result = await client.initialize() + + # If we got here without error, SSE parsing works + assert isinstance(result, dict) + assert "protocolVersion" in result + + # Also verify list_tools works (another SSE response) + tools = await client.list_tools() + assert len(tools) > 0 + assert all(hasattr(t, "name") for t in tools) diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_integration.py b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py new file mode 100644 index 0000000000..70658dbaaf --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py @@ -0,0 +1,389 @@ +""" +Integration tests for MCP client and MCPToolBlock against a real HTTP server. + +These tests spin up a local MCP test server and run the full client/block flow +against it — no mocking, real HTTP requests. +""" + +import asyncio +import json +import threading +from unittest.mock import patch + +import pytest +from aiohttp import web +from pydantic import SecretStr + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.test_server import create_test_mcp_app +from backend.data.model import OAuth2Credentials + +MOCK_USER_ID = "test-user-integration" + + +class _MCPTestServer: + """ + Run an MCP test server in a background thread with its own event loop. + This avoids event loop conflicts with pytest-asyncio. + """ + + def __init__(self, auth_token: str | None = None): + self.auth_token = auth_token + self.url: str = "" + self._runner: web.AppRunner | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = threading.Event() + + def _run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._start()) + self._started.set() + self._loop.run_forever() + + async def _start(self): + app = create_test_mcp_app(auth_token=self.auth_token) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + self.url = f"http://127.0.0.1:{port}/mcp" + + def start(self): + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + if not self._started.wait(timeout=5): + raise RuntimeError("MCP test server failed to start within 5 seconds") + return self + + def stop(self): + if self._loop and self._runner: + asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result( + timeout=5 + ) + self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread: + self._thread.join(timeout=5) + + +@pytest.fixture(scope="module") +def mcp_server(): + """Start a local MCP test server in a background thread.""" + server = _MCPTestServer() + server.start() + yield server.url + server.stop() + + +@pytest.fixture(scope="module") +def mcp_server_with_auth(): + """Start a local MCP test server with auth in a background thread.""" + server = _MCPTestServer(auth_token="test-secret-token") + server.start() + yield server.url, "test-secret-token" + server.stop() + + +@pytest.fixture(autouse=True) +def _allow_localhost(): + """ + Allow 127.0.0.1 through SSRF protection for integration tests. + + The Requests class blocks private IPs by default. 
We patch the Requests + constructor to always include 127.0.0.1 as a trusted origin so the local + test server is reachable. + """ + from backend.util.request import Requests + + original_init = Requests.__init__ + + def patched_init(self, *args, **kwargs): + trusted = list(kwargs.get("trusted_origins") or []) + trusted.append("http://127.0.0.1") + kwargs["trusted_origins"] = trusted + original_init(self, *args, **kwargs) + + with patch.object(Requests, "__init__", patched_init): + yield + + +def _make_client(url: str, auth_token: str | None = None) -> MCPClient: + """Create an MCPClient for integration tests.""" + return MCPClient(url, auth_token=auth_token) + + +# ── MCPClient integration tests ────────────────────────────────────── + + +class TestMCPClientIntegration: + """Test MCPClient against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self, mcp_server): + client = _make_client(mcp_server) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert result["serverInfo"]["name"] == "test-mcp-server" + assert "tools" in result["capabilities"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) == 3 + + tool_names = {t.name for t in tools} + assert tool_names == {"get_weather", "add_numbers", "echo"} + + # Check get_weather schema + weather = next(t for t in tools if t.name == "get_weather") + assert weather.description == "Get current weather for a city" + assert "city" in weather.input_schema["properties"] + assert weather.input_schema["required"] == ["city"] + + # Check add_numbers schema + add = next(t for t in tools if t.name == "add_numbers") + assert "a" in add.input_schema["properties"] + assert "b" in add.input_schema["properties"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_get_weather(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert data["city"] == "London" + assert data["temperature"] == 22 + assert data["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_add_numbers(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("add_numbers", {"a": 3, "b": 7}) + + assert not result.is_error + data = json.loads(result.content[0]["text"]) + assert data["result"] == 10 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_echo(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("echo", {"message": "Hello MCP!"}) + + assert not result.is_error + assert result.content[0]["text"] == "Hello MCP!" 
+ + @pytest.mark.asyncio(loop_scope="session") + async def test_call_unknown_tool(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("nonexistent_tool", {}) + + assert result.is_error + assert "Unknown tool" in result.content[0]["text"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_success(self, mcp_server_with_auth): + url, token = mcp_server_with_auth + client = _make_client(url, auth_token=token) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + + tools = await client.list_tools() + assert len(tools) == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_failure(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url, auth_token="wrong-token") + + with pytest.raises(Exception): + await client.initialize() + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_missing(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url) + + with pytest.raises(Exception): + await client.initialize() + + +# ── MCPToolBlock integration tests ─────────────────────────────────── + + +class TestMCPToolBlockIntegration: + """Test MCPToolBlock end-to-end against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_get_weather(self, mcp_server): + """Full flow: discover tools, select one, execute it.""" + # Step 1: Discover tools (simulating what the frontend/API would do) + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + assert len(tools) == 3 + + # Step 2: User selects "get_weather" and we get its schema + weather_tool = next(t for t in tools if t.name == "get_weather") + + # Step 3: Execute the block — no credentials (public server) + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="get_weather", + tool_input_schema=weather_tool.input_schema, + tool_arguments={"city": "Paris"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + result = outputs[0][1] + assert result["city"] == "Paris" + assert result["temperature"] == 22 + assert result["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_add_numbers(self, mcp_server): + """Full flow for add_numbers tool.""" + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + add_tool = next(t for t in tools if t.name == "add_numbers") + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="add_numbers", + tool_input_schema=add_tool.input_schema, + tool_arguments={"a": 42, "b": 58}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1]["result"] == 100 + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_echo_plain_text(self, mcp_server): + """Verify plain text (non-JSON) responses work.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": 
"Hello from AutoGPT!"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Hello from AutoGPT!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_unknown_tool_yields_error(self, mcp_server): + """Calling an unknown tool should yield an error output.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="nonexistent_tool", + tool_arguments={}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "error" + assert "returned an error" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_with_auth(self, mcp_server_with_auth): + """Full flow with authentication via credentials kwarg.""" + url, token = mcp_server_with_auth + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=url, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Authenticated!"}, + ) + + # Pass credentials via the standard kwarg (as the executor would) + test_creds = OAuth2Credentials( + id="test-cred", + provider="mcp", + access_token=SecretStr(token), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Authenticated!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_credentials_runs_without_auth(self, mcp_server): + """Block runs without auth when no credentials are provided.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "No auth needed"}, + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=None + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "No auth needed" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py new file mode 100644 index 0000000000..8cb49b0fee --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -0,0 +1,619 @@ +""" +Tests for MCP client and MCPToolBlock. 
+""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError +from backend.util.test import execute_block_test + +# ── SSE parsing unit tests ─────────────────────────────────────────── + + +class TestSSEParsing: + """Tests for SSE (text/event-stream) response parsing.""" + + def test_parse_sse_simple(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"tools": []} + assert body["id"] == 1 + + def test_parse_sse_with_notifications(self): + """SSE streams can contain notifications (no id) before the response.""" + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","method":"some/notification"}\n' + "\n" + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"ok": True} + assert body["id"] == 2 + + def test_parse_sse_error_response(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n' + ) + body = MCPClient._parse_sse_response(sse) + assert "error" in body + assert body["error"]["code"] == -32600 + + def test_parse_sse_no_data_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("event: message\n\n") + + def test_parse_sse_empty_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("") + + def test_parse_sse_ignores_non_data_lines(self): + sse = ( + ": comment line\n" + "event: message\n" + "id: 123\n" + 'data: {"jsonrpc":"2.0","result":"ok","id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "ok" + + def test_parse_sse_uses_last_response(self): + """If multiple responses exist, use the last one.""" + sse = ( + 'data: {"jsonrpc":"2.0","result":"first","id":1}\n' + "\n" + 'data: {"jsonrpc":"2.0","result":"second","id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "second" + + +# ── MCPClient unit tests ───────────────────────────────────────────── + + +class TestMCPClient: + """Tests for the MCP HTTP client.""" + + def test_build_headers_without_auth(self): + client = MCPClient("https://mcp.example.com") + headers = client._build_headers() + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + + def test_build_headers_with_auth(self): + client = MCPClient("https://mcp.example.com", auth_token="my-token") + headers = client._build_headers() + assert headers["Authorization"] == "Bearer my-token" + + def test_build_jsonrpc_request(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request("tools/list") + assert req["jsonrpc"] == "2.0" + assert req["method"] == "tools/list" + assert "id" in req + assert "params" not in req + + def test_build_jsonrpc_request_with_params(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request( + "tools/call", {"name": "test", "arguments": {"x": 1}} + ) + assert req["params"] == {"name": "test", "arguments": {"x": 1}} + + def test_request_id_increments(self): + client = MCPClient("https://mcp.example.com") + req1 = client._build_jsonrpc_request("tools/list") + req2 = 
client._build_jsonrpc_request("tools/list")
+        assert req2["id"] > req1["id"]
+
+    def test_server_url_trailing_slash_stripped(self):
+        client = MCPClient("https://mcp.example.com/mcp/")
+        assert client.server_url == "https://mcp.example.com/mcp"
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_send_request_success(self):
+        # Patch the HTTP layer (Requests.post) rather than _send_request
+        # itself, so the JSON-RPC handling under test actually runs.
+        from backend.util.request import Requests
+
+        client = MCPClient("https://mcp.example.com")
+
+        class _Response:
+            headers = {"content-type": "application/json"}
+
+            def json(self):
+                return {"jsonrpc": "2.0", "result": {"tools": []}, "id": 1}
+
+        with patch.object(Requests, "post", AsyncMock(return_value=_Response())):
+            result = await client._send_request("tools/list")
+
+        assert result == {"tools": []}
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_send_request_error(self):
+        # A JSON-RPC error body should surface as MCPClientError.
+        from backend.util.request import Requests
+
+        client = MCPClient("https://mcp.example.com")
+
+        class _Response:
+            headers = {"content-type": "application/json"}
+
+            def json(self):
+                return {
+                    "jsonrpc": "2.0",
+                    "error": {"code": -32600, "message": "Invalid Request"},
+                    "id": 1,
+                }
+
+        with patch.object(Requests, "post", AsyncMock(return_value=_Response())):
+            with pytest.raises(MCPClientError, match="Invalid Request"):
+                await client._send_request("tools/list")
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_list_tools(self):
+        client = MCPClient("https://mcp.example.com")
+
+        mock_result = {
+            "tools": [
+                {
+                    "name": "get_weather",
+                    "description": "Get current weather for a city",
+                    "inputSchema": {
+                        "type": "object",
+                        "properties": {"city": {"type": "string"}},
+                        "required": ["city"],
+                    },
+                },
+                {
+                    "name": "search",
+                    "description": "Search the web",
+                    "inputSchema": {
+                        "type": "object",
+                        "properties": {"query": {"type": "string"}},
+                        "required": ["query"],
+                    },
+                },
+            ]
+        }
+
+        with patch.object(client, "_send_request", return_value=mock_result):
+            tools = await client.list_tools()
+
+        assert len(tools) == 2
+        assert tools[0].name == "get_weather"
+        assert tools[0].description == "Get current weather for a city"
+        assert tools[0].input_schema["properties"]["city"]["type"] == "string"
+        assert tools[1].name == "search"
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_list_tools_empty(self):
+        client = MCPClient("https://mcp.example.com")
+
+        with patch.object(client, "_send_request", return_value={"tools": []}):
+            tools = await client.list_tools()
+
+        assert tools == []
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_list_tools_none_result(self):
+        client = MCPClient("https://mcp.example.com")
+
+        with patch.object(client, "_send_request", return_value=None):
+            tools = await client.list_tools()
+
+        assert tools == []
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_call_tool_success(self):
+        client = MCPClient("https://mcp.example.com")
+
+        mock_result = {
+            "content": [
+                {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
+            ],
+            "isError": False,
+        }
+
+        with patch.object(client, "_send_request", return_value=mock_result):
+            result = await client.call_tool("get_weather", {"city": "London"})
+
+        assert not result.is_error
+        assert len(result.content) == 1
+        assert result.content[0]["type"] == "text"
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_call_tool_error(self):
+        client = MCPClient("https://mcp.example.com")
+
+        mock_result = {
+            "content": [{"type": "text", "text": "City not found"}],
+            "isError": True,
+        }
+
+        with patch.object(client, "_send_request", return_value=mock_result):
+            result = await client.call_tool("get_weather", {"city": "???"})
+
+        assert result.is_error
+
+    @pytest.mark.asyncio(loop_scope="session")
+    async def 
test_call_tool_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "test-server", "version": "1.0.0"}, + } + + with ( + patch.object(client, "_send_request", return_value=mock_result) as mock_req, + patch.object(client, "_send_notification") as mock_notif, + ): + result = await client.initialize() + + mock_req.assert_called_once() + mock_notif.assert_called_once_with("notifications/initialized") + assert result["protocolVersion"] == "2025-03-26" + + +# ── MCPToolBlock unit tests ────────────────────────────────────────── + +MOCK_USER_ID = "test-user-123" + + +class TestMCPToolBlock: + """Tests for the MCPToolBlock.""" + + def test_block_instantiation(self): + block = MCPToolBlock() + assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4" + assert block.name == "MCPToolBlock" + + def test_input_schema_has_required_fields(self): + block = MCPToolBlock() + schema = block.input_schema.jsonschema() + props = schema.get("properties", {}) + assert "server_url" in props + assert "selected_tool" in props + assert "tool_arguments" in props + assert "credentials" in props + + def test_output_schema(self): + block = MCPToolBlock() + schema = block.output_schema.jsonschema() + props = schema.get("properties", {}) + assert "result" in props + assert "error" in props + + def test_get_input_schema_with_tool_schema(self): + tool_schema = { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + data = {"tool_input_schema": tool_schema} + result = MCPToolBlock.Input.get_input_schema(data) + assert result == tool_schema + + def test_get_input_schema_without_tool_schema(self): + result = MCPToolBlock.Input.get_input_schema({}) + assert result == {} + + def test_get_input_defaults(self): + data = {"tool_arguments": {"city": "London"}} + result = MCPToolBlock.Input.get_input_defaults(data) + assert result == {"city": "London"} + + def test_get_missing_input(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "units": {"type": "string"}, + }, + "required": ["city", "units"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == {"units"} + + def test_get_missing_input_all_present(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == set() + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_mock(self): + """Test the block using the built-in test infrastructure.""" + block = MCPToolBlock() + await execute_block_test(block) + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_server_url(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="", + selected_tool="test", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [("error", "MCP server URL is required")] + + 
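Taken together, `test_run_missing_server_url` above and `test_run_missing_tool` below pin down the block's guard-clause contract. A minimal sketch of that contract, assuming an async-generator `run()` like the one these tests drive; only the two error strings come from the assertions, every other name here is hypothetical:

```python
import asyncio
from typing import AsyncIterator


async def run_guards(
    server_url: str, selected_tool: str
) -> AsyncIterator[tuple[str, str]]:
    """Hypothetical stand-in for MCPToolBlock.run()'s validation layer."""
    if not server_url:
        yield "error", "MCP server URL is required"
        return
    if not selected_tool:
        yield "error", "No tool selected. Please select a tool from the dropdown."
        return
    yield "result", "ok"  # placeholder for the real tools/call round-trip


async def demo() -> None:
    async for name, data in run_guards("", "get_weather"):
        print(name, data)  # -> error MCP server URL is required


asyncio.run(demo())
```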
@pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_tool(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [ + ("error", "No tool selected. Please select a tool from the dropdown.") + ] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_success(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="get_weather", + tool_input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + tool_arguments={"city": "London"}, + ) + + async def mock_call(*args, **kwargs): + return {"temp": 20, "city": "London"} + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == {"temp": 20, "city": "London"} + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_mcp_error(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="bad_tool", + ) + + async def mock_call(*args, **kwargs): + raise MCPClientError("Tool not found") + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert outputs[0][0] == "error" + assert "Tool not found" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_parses_json_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": '{"temp": 20}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == {"temp": 20} + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_plain_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Hello, world!"}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == "Hello, world!" 
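The `_call_mcp_tool` cases on either side of this point (JSON text, plain text, multiple items, error result, image) jointly specify how a tool call's `content` list is folded into a single block output. A rough reconstruction inferred only from those assertions; the real helper is part of the block and does not appear in this diff:

```python
import json
from typing import Any


class MCPClientError(Exception):
    """Stand-in for backend.blocks.mcp.client.MCPClientError."""


def fold_content(content: list[dict[str, Any]], is_error: bool) -> Any:
    """Collapse MCP tool-call content into one value, per the tests' assertions."""
    if is_error:
        # Mirrors test_call_mcp_tool_error_result: error results raise.
        raise MCPClientError("MCP tool returned an error")

    def decode(item: dict[str, Any]) -> Any:
        if item.get("type") == "text":
            try:
                return json.loads(item["text"])  # '{"temp": 20}' -> {"temp": 20}
            except (json.JSONDecodeError, TypeError):
                return item["text"]  # plain text passes through unchanged
        return item  # non-text (e.g. image) content is returned verbatim

    decoded = [decode(item) for item in content]
    return decoded[0] if len(decoded) == 1 else decoded


assert fold_content([{"type": "text", "text": '{"temp": 20}'}], False) == {"temp": 20}
assert fold_content([{"type": "text", "text": "Hello, world!"}], False) == "Hello, world!"
assert fold_content(
    [{"type": "text", "text": "Part 1"}, {"type": "text", "text": '{"part": 2}'}],
    False,
) == ["Part 1", {"part": 2}]
```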
+ + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_multiple_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Part 1"}, + {"type": "text", "text": '{"part": 2}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == ["Part 1", {"part": 2}] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_error_result(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[{"type": "text", "text": "Something went wrong"}], + is_error=True, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + with pytest.raises(MCPClientError, match="returned an error"): + await block._call_mcp_tool("https://mcp.example.com", "test_tool", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_image_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_credentials(self): + """Verify the block uses OAuth2Credentials and passes auth token.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + test_creds = OAuth2Credentials( + id="cred-123", + provider="mcp", + access_token=SecretStr("resolved-token"), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + async for _ in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + pass + + assert captured_tokens == ["resolved-token"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_without_credentials(self): + """Verify the block works without credentials (public server).""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert captured_tokens 
== [None] + assert outputs == [("result", "ok")] diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py new file mode 100644 index 0000000000..e9a42f68ea --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py @@ -0,0 +1,242 @@ +""" +Tests for MCP OAuth handler. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials + + +def _mock_response(json_data: dict, status: int = 200) -> MagicMock: + """Create a mock Response with synchronous json() (matching Requests.Response).""" + resp = MagicMock() + resp.status = status + resp.ok = 200 <= status < 300 + resp.json.return_value = json_data + return resp + + +class TestMCPOAuthHandler: + """Tests for the MCPOAuthHandler.""" + + def _make_handler(self, **overrides) -> MCPOAuthHandler: + defaults = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "redirect_uri": "https://app.example.com/callback", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + } + defaults.update(overrides) + return MCPOAuthHandler(**defaults) + + def test_get_login_url_basic(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read", "write"], + state="random-state-token", + code_challenge="S256-challenge-value", + ) + + assert "https://auth.example.com/authorize?" in url + assert "response_type=code" in url + assert "client_id=test-client-id" in url + assert "state=random-state-token" in url + assert "code_challenge=S256-challenge-value" in url + assert "code_challenge_method=S256" in url + assert "scope=read+write" in url + + def test_get_login_url_with_resource(self): + handler = self._make_handler(resource_url="https://mcp.example.com/mcp") + url = handler.get_login_url( + scopes=[], state="state", code_challenge="challenge" + ) + + assert "resource=https" in url + + def test_get_login_url_without_pkce(self): + handler = self._make_handler() + url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None) + + assert "code_challenge" not in url + assert "code_challenge_method" not in url + + @pytest.mark.asyncio(loop_scope="session") + async def test_exchange_code_for_tokens(self): + handler = self._make_handler() + + resp = _mock_response( + { + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + creds = await handler.exchange_code_for_tokens( + code="auth-code", + scopes=["read"], + code_verifier="pkce-verifier", + ) + + assert isinstance(creds, OAuth2Credentials) + assert creds.access_token.get_secret_value() == "new-access-token" + assert creds.refresh_token is not None + assert creds.refresh_token.get_secret_value() == "new-refresh-token" + assert creds.scopes == ["read"] + assert creds.access_token_expires_at is not None + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens(self): + handler = self._make_handler() + + existing_creds = OAuth2Credentials( + id="existing-id", + provider="mcp", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("old-refresh"), + 
scopes=["read"], + title="test", + ) + + resp = _mock_response( + { + "access_token": "refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + refreshed = await handler._refresh_tokens(existing_creds) + + assert refreshed.id == "existing-id" + assert refreshed.access_token.get_secret_value() == "refreshed-token" + assert refreshed.refresh_token is not None + assert refreshed.refresh_token.get_secret_value() == "new-refresh" + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens_no_refresh_token(self): + handler = self._make_handler() + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=["read"], + title="test", + ) + + with pytest.raises(ValueError, match="No refresh token"): + await handler._refresh_tokens(creds) + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_no_url(self): + handler = self._make_handler(revoke_url=None) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + result = await handler.revoke_tokens(creds) + assert result is False + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_with_url(self): + handler = self._make_handler(revoke_url="https://auth.example.com/revoke") + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + resp = _mock_response({}, status=200) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + result = await handler.revoke_tokens(creds) + + assert result is True + + +class TestMCPClientDiscovery: + """Tests for MCPClient OAuth metadata discovery.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + metadata = { + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + + resp = _mock_response(metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is not None + assert result["authorization_servers"] == ["https://auth.example.com"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_not_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + resp = _mock_response({}, status=404) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_server_metadata(self): + client = MCPClient("https://mcp.example.com/mcp") + + server_metadata = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + "code_challenge_methods_supported": ["S256"], + } + + resp = _mock_response(server_metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: 
+ instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth_server_metadata( + "https://auth.example.com" + ) + + assert result is not None + assert result["authorization_endpoint"] == "https://auth.example.com/authorize" + assert result["token_endpoint"] == "https://auth.example.com/token" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_server.py b/autogpt_platform/backend/backend/blocks/mcp/test_server.py new file mode 100644 index 0000000000..a6732932bc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_server.py @@ -0,0 +1,162 @@ +""" +Minimal MCP server for integration testing. + +Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST) +with a few sample tools. Runs on localhost with a random available port. +""" + +import json +import logging + +from aiohttp import web + +logger = logging.getLogger(__name__) + +# Sample tools this test server exposes +TEST_TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + }, + "required": ["city"], + }, + }, + { + "name": "add_numbers", + "description": "Add two numbers together", + "inputSchema": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + { + "name": "echo", + "description": "Echo back the input message", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "Message to echo"}, + }, + "required": ["message"], + }, + }, +] + + +def _handle_initialize(params: dict) -> dict: + return { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"}, + } + + +def _handle_tools_list(params: dict) -> dict: + return {"tools": TEST_TOOLS} + + +def _handle_tools_call(params: dict) -> dict: + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_weather": + city = arguments.get("city", "Unknown") + return { + "content": [ + { + "type": "text", + "text": json.dumps( + {"city": city, "temperature": 22, "condition": "sunny"} + ), + } + ], + } + + elif tool_name == "add_numbers": + a = arguments.get("a", 0) + b = arguments.get("b", 0) + return { + "content": [{"type": "text", "text": json.dumps({"result": a + b})}], + } + + elif tool_name == "echo": + message = arguments.get("message", "") + return { + "content": [{"type": "text", "text": message}], + } + + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } + + +HANDLERS = { + "initialize": _handle_initialize, + "tools/list": _handle_tools_list, + "tools/call": _handle_tools_call, +} + + +async def handle_mcp_request(request: web.Request) -> web.Response: + """Handle incoming MCP JSON-RPC 2.0 requests.""" + # Check auth if configured + expected_token = request.app.get("auth_token") + if expected_token: + auth_header = request.headers.get("Authorization", "") + if auth_header != f"Bearer {expected_token}": + return web.json_response( + { + "jsonrpc": "2.0", + "error": {"code": -32001, "message": "Unauthorized"}, + "id": None, + }, + status=401, + ) + + body = await request.json() + + # Handle notifications (no id field) — just acknowledge + 
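+    # (per JSON-RPC 2.0, a request without an "id" is a notification and must
+    # not receive a response body; 202 Accepted signals receipt only)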
if "id" not in body: + return web.Response(status=202) + + method = body.get("method", "") + params = body.get("params", {}) + request_id = body.get("id") + + handler = HANDLERS.get(method) + if not handler: + return web.json_response( + { + "jsonrpc": "2.0", + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + "id": request_id, + } + ) + + result = handler(params) + return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id}) + + +def create_test_mcp_app(auth_token: str | None = None) -> web.Application: + """Create an aiohttp app that acts as an MCP server.""" + app = web.Application() + app.router.add_post("/mcp", handle_mcp_request) + if auth_token: + app["auth_token"] = auth_token + return app diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index f39a0144e7..94f99852e8 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -33,6 +33,7 @@ from backend.util import type as type_utils from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError from backend.util.json import SafeJson from backend.util.models import Pagination +from backend.util.request import parse_url from .block import BlockInput from .db import BaseDbModel @@ -449,6 +450,9 @@ class GraphModel(Graph, GraphMeta): continue if ProviderName.HTTP in field.provider: continue + # MCP credentials are intentionally split by server URL + if ProviderName.MCP in field.provider: + continue # If this happens, that means a block implementation probably needs # to be updated. @@ -505,6 +509,18 @@ class GraphModel(Graph, GraphMeta): "required": ["id", "provider", "type"], } + # Add a descriptive display title when URL-based discriminator values + # are present (e.g. "mcp.sentry.dev" instead of just "Mcp") + if ( + field_info.discriminator + and not field_info.discriminator_mapping + and field_info.discriminator_values + ): + hostnames = sorted( + parse_url(str(v)).netloc for v in field_info.discriminator_values + ) + field_schema["display_name"] = ", ".join(hostnames) + # Add other (optional) field info items field_schema.update( field_info.model_dump( @@ -549,8 +565,17 @@ class GraphModel(Graph, GraphMeta): for graph in [self] + self.sub_graphs: for node in graph.nodes: - # Track if this node requires credentials (credentials_optional=False means required) - node_required_map[node.id] = not node.credentials_optional + # A node's credentials are optional if either: + # 1. The node metadata says so (credentials_optional=True), or + # 2. All credential fields on the block have defaults (not required by schema) + block_required = node.block.input_schema.get_required_fields() + creds_required_by_schema = any( + fname in block_required + for fname in node.block.input_schema.get_credentials_fields() + ) + node_required_map[node.id] = ( + not node.credentials_optional and creds_required_by_schema + ) for ( field_name, @@ -776,6 +801,19 @@ class GraphModel(Graph, GraphMeta): "'credentials' and `*_credentials` are reserved" ) + # Check custom block-level validation (e.g., MCP dynamic tool arguments). + # Blocks can override get_missing_input to report additional missing fields + # beyond the standard top-level required fields. 
+ if for_run: + credential_fields = InputSchema.get_credentials_fields() + custom_missing = InputSchema.get_missing_input(node.input_default) + for field_name in custom_missing: + if ( + field_name not in provided_inputs + and field_name not in credential_fields + ): + node_errors[node.id][field_name] = "This field is required" + # Get input schema properties and check dependencies input_fields = InputSchema.model_fields diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 442c8ed4be..3cb6f24b87 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -462,3 +462,120 @@ def test_node_credentials_optional_with_other_metadata(): assert node.credentials_optional is True assert node.metadata["position"] == {"x": 100, "y": 200} assert node.metadata["customized_name"] == "My Custom Node" + + +# ============================================================================ +# Tests for MCP Credential Deduplication +# ============================================================================ + + +def test_mcp_credential_combine_different_servers(): + """Two MCP credential fields with different server URLs should produce + separate entries when combined (not merged into one).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_sentry = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_linear = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.linear.app/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_sentry, ("node-sentry", "credentials")), + (field_linear, ("node-linear", "credentials")), + ) + + # Should produce 2 separate credential entries + assert len(combined) == 2, ( + f"Expected 2 credential entries for 2 MCP blocks with different servers, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + # Each entry should contain the server hostname in its key + keys = list(combined.keys()) + assert any( + "mcp.sentry.dev" in k for k in keys + ), f"Expected 'mcp.sentry.dev' in one key, got {keys}" + assert any( + "mcp.linear.app" in k for k in keys + ), f"Expected 'mcp.linear.app' in one key, got {keys}" + + +def test_mcp_credential_combine_same_server(): + """Two MCP credential fields with the same server URL should be combined + into one credential entry.""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + + combined = 
CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 credential entry (same server URL) + assert len(combined) == 1, ( + f"Expected 1 credential entry for 2 MCP blocks with same server, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + +def test_mcp_credential_combine_no_discriminator_values(): + """MCP credential fields without discriminator_values should be merged + into a single entry (backwards compat for blocks without server_url set).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 entry (no URL differentiation) + assert len(combined) == 1, ( + f"Expected 1 credential entry for MCP blocks without discriminator_values, " + f"got {len(combined)}: {list(combined.keys())}" + ) diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index e61f7efbd0..c9d8c5879f 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -29,6 +29,7 @@ from pydantic import ( GetCoreSchemaHandler, SecretStr, field_serializer, + model_validator, ) from pydantic_core import ( CoreSchema, @@ -502,6 +503,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): provider: CP type: CT + @model_validator(mode="before") + @classmethod + def _normalize_legacy_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug. + + Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"`` + instead of the plain value. Old stored credential references may have + ``provider: "ProviderName.MCP"`` instead of ``"mcp"``. + """ + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + @classmethod def allowed_providers(cls) -> tuple[ProviderName, ...] | None: return get_args(cls.model_fields["provider"].annotation) @@ -606,11 +626,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): ] = defaultdict(list) for field, key in fields: - if field.provider == frozenset([ProviderName.HTTP]): - # HTTP host-scoped credentials can have different hosts that reqires different credential sets. - # Group by host extracted from the URL + if ( + field.discriminator + and not field.discriminator_mapping + and field.discriminator_values + ): + # URL-based discrimination (e.g. HTTP host-scoped, MCP server URL): + # Each unique host gets its own credential entry. + provider_prefix = next(iter(field.provider)) + # Use .value for enum types to get the plain string (e.g. 
"mcp" not "ProviderName.MCP") + prefix_str = getattr(provider_prefix, "value", str(provider_prefix)) providers = frozenset( - [cast(CP, "http")] + [cast(CP, prefix_str)] + [ cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 1f76458947..caa98784c2 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -20,6 +20,7 @@ from backend.blocks import get_block from backend.blocks._base import BlockSchema from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentOutputBlock +from backend.blocks.mcp.block import MCPToolBlock from backend.data import redis_client as redis from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry from backend.data.credit import UsageTransactionMetadata @@ -228,6 +229,18 @@ async def execute_node( _input_data.nodes_input_masks = nodes_input_masks _input_data.user_id = user_id input_data = _input_data.model_dump() + elif isinstance(node_block, MCPToolBlock): + _mcp_data = MCPToolBlock.Input(**node.input_default) + # Dynamic tool fields are flattened to top-level by validate_exec + # (via get_input_defaults). Collect them back into tool_arguments. + tool_schema = _mcp_data.tool_input_schema + tool_props = set(tool_schema.get("properties", {}).keys()) + merged_args = {**_mcp_data.tool_arguments} + for key in tool_props: + if key in input_data: + merged_args[key] = input_data[key] + _mcp_data.tool_arguments = merged_args + input_data = _mcp_data.model_dump() data.inputs = input_data # Execute the node @@ -264,8 +277,34 @@ async def execute_node( # Handle regular credentials fields for field_name, input_type in input_model.get_credentials_fields().items(): - credentials_meta = input_type(**input_data[field_name]) - credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id) + field_value = input_data.get(field_name) + if not field_value or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + # No credentials configured — nullify so JSON schema validation + # doesn't choke on the empty default `{}`. + input_data[field_name] = None + continue # Block runs without credentials + + credentials_meta = input_type(**field_value) + # Write normalized values back so JSON schema validation also passes + # (model_validator may have fixed legacy formats like "ProviderName.MCP") + input_data[field_name] = credentials_meta.model_dump(mode="json") + try: + credentials, lock = await creds_manager.acquire( + user_id, credentials_meta.id + ) + except ValueError: + # Credential was deleted or doesn't exist. + # If the field has a default, run without credentials. 
+ if input_model.model_fields[field_name].default is not None: + log_metadata.warning( + f"Credentials #{credentials_meta.id} not found, " + "running without (field has default)" + ) + input_data[field_name] = None + continue + raise creds_locks.append(lock) extra_exec_kwargs[field_name] = credentials diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index bb5da1e527..2b9a454061 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -260,7 +260,13 @@ async def _validate_node_input_credentials( # Track if any credential field is missing for this node has_missing_credentials = False + # A credential field is optional if the node metadata says so, or if + # the block schema declares a default for the field. + required_fields = block.input_schema.get_required_fields() + is_creds_optional = node.credentials_optional + for field_name, credentials_meta_type in credentials_fields.items(): + field_is_optional = is_creds_optional or field_name not in required_fields try: # Check nodes_input_masks first, then input_default field_value = None @@ -273,7 +279,7 @@ async def _validate_node_input_credentials( elif field_name in node.input_default: # For optional credentials, don't use input_default - treat as missing # This prevents stale credential IDs from failing validation - if node.credentials_optional: + if field_is_optional: field_value = None else: field_value = node.input_default[field_name] @@ -283,8 +289,8 @@ async def _validate_node_input_credentials( isinstance(field_value, dict) and not field_value.get("id") ): has_missing_credentials = True - # If node has credentials_optional flag, mark for skipping instead of error - if node.credentials_optional: + # If credential field is optional, skip instead of error + if field_is_optional: continue # Don't add error, will be marked for skip after loop else: credential_errors[node.id][ @@ -334,16 +340,16 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - # If node has optional credentials and any are missing, mark for skipping - # But only if there are no other errors for this node + # If node has optional credentials and any are missing, allow running without. + # The executor will pass credentials=None to the block's run(). 
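+        # (note: unlike before, the node is no longer added to nodes_to_skip)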
if ( has_missing_credentials - and node.credentials_optional + and is_creds_optional and node.id not in credential_errors ): - nodes_to_skip.add(node.id) logger.info( - f"Node #{node.id} will be skipped: optional credentials not configured" + f"Node #{node.id}: optional credentials not configured, " + "running without" ) return credential_errors, nodes_to_skip diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index db33249583..069086a6fd 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph @@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( nodes_input_masks=None, ) - # Node should be in nodes_to_skip, not in errors - assert mock_node.id in nodes_to_skip + # Node should NOT be in nodes_to_skip (runs without credentials) and not in errors + assert mock_node.id not in nodes_to_skip assert mock_node.id not in errors @@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph diff --git a/autogpt_platform/backend/backend/integrations/credentials_store.py b/autogpt_platform/backend/backend/integrations/credentials_store.py index 384405b0c7..3e79a6c047 100644 --- a/autogpt_platform/backend/backend/integrations/credentials_store.py +++ b/autogpt_platform/backend/backend/integrations/credentials_store.py @@ -22,6 +22,27 @@ from backend.util.settings import Settings settings = Settings() + +def provider_matches(stored: str, expected: str) -> bool: + """Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug. + + On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"`` + instead of ``"mcp"``. OAuth states persisted with the buggy format need + to match when ``expected`` is the canonical value (e.g. ``"mcp"``). 
+ """ + if stored == expected: + return True + if stored.startswith("ProviderName."): + member = stored.removeprefix("ProviderName.") + from backend.integrations.providers import ProviderName + + try: + return ProviderName[member].value == expected + except KeyError: + pass + return False + + # This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached ollama_credentials = APIKeyCredentials( id="744fdc56-071a-4761-b5a5-0af0ce10a2b5", @@ -389,7 +410,7 @@ class IntegrationCredentialsStore: self, user_id: str, provider: str ) -> list[Credentials]: credentials = await self.get_all_creds(user_id) - return [c for c in credentials if c.provider == provider] + return [c for c in credentials if provider_matches(c.provider, provider)] async def get_authorized_providers(self, user_id: str) -> list[str]: credentials = await self.get_all_creds(user_id) @@ -485,17 +506,6 @@ class IntegrationCredentialsStore: async with self.edit_user_integrations(user_id) as user_integrations: user_integrations.oauth_states.append(state) - async with await self.locked_user_integrations(user_id): - - user_integrations = await self._get_user_integrations(user_id) - oauth_states = user_integrations.oauth_states - oauth_states.append(state) - user_integrations.oauth_states = oauth_states - - await self.db_manager.update_user_integrations( - user_id=user_id, data=user_integrations - ) - return token, code_challenge def _generate_code_challenge(self) -> tuple[str, str]: @@ -521,7 +531,7 @@ class IntegrationCredentialsStore: state for state in oauth_states if secrets.compare_digest(state.token, token) - and state.provider == provider + and provider_matches(state.provider, provider) and state.expires_at > now.timestamp() ), None, diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index f2b6a9da4f..5634dd73b6 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock from backend.data.model import Credentials, OAuth2Credentials from backend.data.redis_client import get_redis_async -from backend.integrations.credentials_store import IntegrationCredentialsStore +from backend.integrations.credentials_store import ( + IntegrationCredentialsStore, + provider_matches, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.util.exceptions import MissingConfigError @@ -137,7 +140,10 @@ class IntegrationCredentialsManager: self, user_id: str, credentials: OAuth2Credentials, lock: bool = True ) -> OAuth2Credentials: async with self._locked(user_id, credentials.id, "refresh"): - oauth_handler = await _get_provider_oauth_handler(credentials.provider) + if provider_matches(credentials.provider, ProviderName.MCP.value): + oauth_handler = create_mcp_oauth_handler(credentials) + else: + oauth_handler = await _get_provider_oauth_handler(credentials.provider) if oauth_handler.needs_refresh(credentials): logger.debug( f"Refreshing '{credentials.provider}' " @@ -236,3 +242,31 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl client_secret=client_secret, redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback", ) + + +def create_mcp_oauth_handler( + credentials: OAuth2Credentials, +) -> 
"BaseOAuthHandler": + """Create an MCPOAuthHandler from credential metadata for token refresh. + + MCP OAuth handlers have dynamic endpoints discovered per-server, so they + can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler + is reconstructed from metadata stored on the credential during initial auth. + """ + from backend.blocks.mcp.oauth import MCPOAuthHandler + + meta = credentials.metadata or {} + token_url = meta.get("mcp_token_url", "") + if not token_url: + raise ValueError( + f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; " + "cannot refresh tokens" + ) + return MCPOAuthHandler( + client_id=meta.get("mcp_client_id", ""), + client_secret=meta.get("mcp_client_secret", ""), + redirect_uri="", # Not needed for token refresh + authorize_url="", # Not needed for token refresh + token_url=token_url, + resource_url=meta.get("mcp_resource_url"), + ) diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index 8a0d6fd183..a462cd787f 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -30,6 +30,7 @@ class ProviderName(str, Enum): IDEOGRAM = "ideogram" JINA = "jina" LLAMA_API = "llama_api" + MCP = "mcp" MEDIUM = "medium" MEM0 = "mem0" NOTION = "notion" diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 99eee404b9..8fdbe10383 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -51,6 +51,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str): if ( creds_meta := new_node.input_default.get(creds_field_name) ) and not await get_credentials(creds_meta["id"]): + # If the credential field is optional (has a default in the + # schema, or node metadata marks it optional), clear the stale + # reference instead of blocking the save. 
+ creds_field_optional = ( + new_node.credentials_optional + or creds_field_name not in block_input_schema.get_required_fields() + ) + if creds_field_optional: + new_node.input_default[creds_field_name] = {} + logger.warning( + f"Node #{new_node.id}: cleared stale optional " + f"credentials #{creds_meta['id']} for " + f"'{creds_field_name}'" + ) + continue raise ValueError( f"Node #{new_node.id} input '{creds_field_name}' updated with " f"non-existent credentials #{creds_meta['id']}" diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index fbd3573112..4eadc41333 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -38,6 +38,7 @@ class Flag(str, Enum): AGENT_ACTIVITY = "agent-activity" ENABLE_PLATFORM_PAYMENT = "enable-platform-payment" CHAT = "chat" + COPILOT_SDK = "copilot-sdk" def is_configured() -> bool: diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 95e5ee32f7..9470909dfc 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver): def __init__(self, ssl_hostname: str, ip_addresses: list[str]): self.ssl_hostname = ssl_hostname self.ip_addresses = ip_addresses - self._default = aiohttp.AsyncResolver() + self._default = aiohttp.ThreadedResolver() async def resolve(self, host, port=0, family=socket.AF_INET): if host == self.ssl_hostname: @@ -467,7 +467,7 @@ class Requests: resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses) ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context) - session_kwargs = {} + session_kwargs: dict = {} if connector: session_kwargs["connector"] = connector diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 48dadb88f1..c5cca87b6e 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -662,6 +662,17 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): mem0_api_key: str = Field(default="", description="Mem0 API key") elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key") + linear_api_key: str = Field( + default="", description="Linear API key for system-level operations" + ) + linear_feature_request_project_id: str = Field( + default="", + description="Linear project ID where feature requests are tracked", + ) + linear_feature_request_team_id: str = Field( + default="", + description="Linear team ID used when creating feature request issues", + ) linear_client_id: str = Field(default="", description="Linear client ID") linear_client_secret: str = Field(default="", description="Linear client secret") diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 53b5030da6..8062457a70 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -441,14 +441,14 @@ develop = true colorama = "^0.4.6" cryptography = "^46.0" expiringdict = "^1.2.2" -fastapi = "^0.128.0" +fastapi = "^0.128.7" google-cloud-logging = "^3.13.0" -launchdarkly-server-sdk = "^9.14.1" +launchdarkly-server-sdk = "^9.15.0" pydantic = "^2.12.5" pydantic-settings = "^2.12.0" pyjwt = {version = "^2.11.0", extras = ["crypto"]} redis = "^6.2.0" -supabase = 
"^2.27.2" +supabase = "^2.28.0" uvicorn = "^0.40.0" [package.source] @@ -897,6 +897,29 @@ files = [ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.35" +description = "Python SDK for Claude Code" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"}, + {file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"}, + {file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"}, +] + +[package.dependencies] +anyio = ">=4.0.0" +mcp = ">=0.1.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] + [[package]] name = "cleo" version = "2.1.0" @@ -1382,14 +1405,14 @@ tzdata = "*" [[package]] name = "fastapi" -version = "0.128.6" +version = "0.128.7" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "fastapi-0.128.6-py3-none-any.whl", hash = "sha256:bb1c1ef87d6086a7132d0ab60869d6f1ee67283b20fbf84ec0003bd335099509"}, - {file = "fastapi-0.128.6.tar.gz", hash = "sha256:0cb3946557e792d731b26a42b04912f16367e3c3135ea8290f620e234f2b604f"}, + {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"}, + {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"}, ] [package.dependencies] @@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.3" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, + {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, +] + [[package]] name = "huggingface-hub" version = "1.4.1" @@ -3117,14 +3152,14 @@ urllib3 = ">=1.26.0,<3" [[package]] name = "launchdarkly-server-sdk" -version = "9.14.1" +version = "9.15.0" description = "LaunchDarkly SDK for Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"}, - {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"}, + {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"}, + {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"}, ] [package.dependencies] @@ -3310,6 +3345,39 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.26.0" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"}, + {file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27.1" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.11.0,<3.0.0" +pydantic-settings = ">=2.5.2" +pyjwt = {version = ">=2.10.1", extras = ["crypto"]} +python-multipart = ">=0.0.9" +pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""} +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +typing-extensions = ">=4.9.0" +typing-inspection = ">=0.4.1" +uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -4728,14 +4796,14 @@ tests = ["coverage-conditional-plugin (>=0.9.0)", "portalocker[redis]", "pytest [[package]] name = "postgrest" -version = "2.27.3" +version = "2.28.0" description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." 
optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "postgrest-2.27.3-py3-none-any.whl", hash = "sha256:ed79123af7127edd78d538bfe8351d277e45b1a36994a4dbf57ae27dde87a7b7"}, - {file = "postgrest-2.27.3.tar.gz", hash = "sha256:c2e2679addfc8eaab23197bad7ddaee6cbb4cbe8c483ebd2d2e5219543037cc3"}, + {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"}, + {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"}, ] [package.dependencies] @@ -5994,7 +6062,7 @@ description = "Python for Window Extensions" optional = false python-versions = "*" groups = ["main"] -markers = "platform_system == \"Windows\"" +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, @@ -6260,14 +6328,14 @@ all = ["numpy"] [[package]] name = "realtime" -version = "2.27.3" +version = "2.28.0" description = "" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "realtime-2.27.3-py3-none-any.whl", hash = "sha256:f571115f86988e33c41c895cb3fba2eaa1b693aeaede3617288f44274ca90f43"}, - {file = "realtime-2.27.3.tar.gz", hash = "sha256:02b082243107656a5ef3fb63e8e2ab4c40bc199abb45adb8a42ed63f089a1041"}, + {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"}, + {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"}, ] [package.dependencies] @@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "3.2.0" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"}, + {file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"}, +] + +[package.dependencies] +anyio = ">=4.7.0" +starlette = ">=0.49.1" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stagehand" version = "0.5.9" @@ -7024,14 +7114,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart [[package]] name = "storage3" -version = "2.27.3" +version = "2.28.0" description = "Supabase Storage client for Python." 
optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "storage3-2.27.3-py3-none-any.whl", hash = "sha256:11a05b7da84bccabeeea12d940bca3760cf63fe6ca441868677335cfe4fdfbe0"}, - {file = "storage3-2.27.3.tar.gz", hash = "sha256:dc1a4a010cf36d5482c5cb6c1c28fc5f00e23284342b89e4ae43b5eae8501ddb"}, + {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"}, + {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"}, ] [package.dependencies] @@ -7091,35 +7181,35 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""} [[package]] name = "supabase" -version = "2.27.3" +version = "2.28.0" description = "Supabase client for Python." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase-2.27.3-py3-none-any.whl", hash = "sha256:082a74642fcf9954693f1ce8c251baf23e4bda26ffdbc8dcd4c99c82e60d69ff"}, - {file = "supabase-2.27.3.tar.gz", hash = "sha256:5e5a348232ac4315c1032ddd687278f0b982465471f0cbb52bca7e6a66495ff3"}, + {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"}, + {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"}, ] [package.dependencies] httpx = ">=0.26,<0.29" -postgrest = "2.27.3" -realtime = "2.27.3" -storage3 = "2.27.3" -supabase-auth = "2.27.3" -supabase-functions = "2.27.3" +postgrest = "2.28.0" +realtime = "2.28.0" +storage3 = "2.28.0" +supabase-auth = "2.28.0" +supabase-functions = "2.28.0" yarl = ">=1.22.0" [[package]] name = "supabase-auth" -version = "2.27.3" +version = "2.28.0" description = "Python Client Library for Supabase Auth" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_auth-2.27.3-py3-none-any.whl", hash = "sha256:82a4262eaad85383319d394dab0eea11fcf3ebd774062aef8ea3874ae2f02579"}, - {file = "supabase_auth-2.27.3.tar.gz", hash = "sha256:39894d4bc60b6f23b5cff4d0d7d4c1659e5d69563cadf014d4896f780ca8ca78"}, + {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"}, + {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"}, ] [package.dependencies] @@ -7129,14 +7219,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]} [[package]] name = "supabase-functions" -version = "2.27.3" +version = "2.28.0" description = "Library for Supabase Functions" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_functions-2.27.3-py3-none-any.whl", hash = "sha256:9d14a931d49ede1c6cf5fbfceb11c44061535ba1c3f310f15384964d86a83d9e"}, - {file = "supabase_functions-2.27.3.tar.gz", hash = "sha256:e954f1646da8ca6e7e16accef58d0884a5f97b25956ee98e7d4927a210ed92f9"}, + {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"}, + {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"}, ] [package.dependencies] @@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af" +content-hash = 
"55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 317663ee98..7a112e75ca 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -16,6 +16,7 @@ anthropic = "^0.79.0" apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } +claude-agent-sdk = "^0.1.0" click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2" @@ -65,7 +66,7 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal sqlalchemy = "^2.0.40" strenum = "^0.4.9" stripe = "^11.5.0" -supabase = "2.27.3" +supabase = "2.28.0" tenacity = "^9.1.4" todoist-api-python = "^2.1.7" tweepy = "^4.16.0" diff --git a/autogpt_platform/backend/test/blocks/test_jina_extract_website.py b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py new file mode 100644 index 0000000000..335c43f966 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py @@ -0,0 +1,66 @@ +from typing import cast + +import pytest + +from backend.blocks.jina._auth import ( + TEST_CREDENTIALS, + TEST_CREDENTIALS_INPUT, + JinaCredentialsInput, +) +from backend.blocks.jina.search import ExtractWebsiteContentBlock +from backend.util.request import HTTPClientError + + +@pytest.mark.asyncio +async def test_extract_website_content_returns_content(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=True, + ) + + async def fake_get_request(url, json=False, headers=None): + assert url == "https://example.com" + assert headers == {} + return "page content" + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page content") in results + assert all(key != "error" for key, _ in results) + + +@pytest.mark.asyncio +async def test_extract_website_content_handles_http_error(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=False, + ) + + async def fake_get_request(url, json=False, headers=None): + raise HTTPClientError("HTTP 400 Error: Bad Request", 400) + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page content") not in results + error_messages = [value for key, value in results if key == "error"] + assert error_messages + assert "Client error (400)" in error_messages[0] + assert "https://example.com" in error_messages[0] diff --git a/autogpt_platform/backend/test/blocks/test_list_concatenation.py b/autogpt_platform/backend/test/blocks/test_list_concatenation.py new file mode 100644 index 0000000000..8cea3b60f7 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_list_concatenation.py @@ -0,0 +1,1276 @@ +""" +Comprehensive test suite for list concatenation and manipulation blocks. 
+ +Tests cover: +- ConcatenateListsBlock: basic concatenation, deduplication, None removal +- FlattenListBlock: nested list flattening with depth control +- InterleaveListsBlock: round-robin interleaving of multiple lists +- ZipListsBlock: zipping lists with truncation and padding +- ListDifferenceBlock: computing list differences (regular and symmetric) +- ListIntersectionBlock: finding common elements between lists +- Helper utility functions: validation, flattening, deduplication, etc. +""" + +import pytest + +from backend.blocks.data_manipulation import ( + _MAX_FLATTEN_DEPTH, + ConcatenateListsBlock, + FlattenListBlock, + InterleaveListsBlock, + ListDifferenceBlock, + ListIntersectionBlock, + ZipListsBlock, + _compute_nesting_depth, + _concatenate_lists_simple, + _deduplicate_list, + _filter_none_values, + _flatten_nested_list, + _interleave_lists, + _make_hashable, + _validate_all_lists, + _validate_list_input, +) +from backend.util.test import execute_block_test + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestValidateListInput: + """Tests for the _validate_list_input helper.""" + + def test_valid_list_returns_none(self): + assert _validate_list_input([1, 2, 3], 0) is None + + def test_empty_list_returns_none(self): + assert _validate_list_input([], 0) is None + + def test_none_returns_none(self): + assert _validate_list_input(None, 0) is None + + def test_string_returns_error(self): + result = _validate_list_input("hello", 0) + assert result is not None + assert "str" in result + assert "index 0" in result + + def test_integer_returns_error(self): + result = _validate_list_input(42, 1) + assert result is not None + assert "int" in result + assert "index 1" in result + + def test_dict_returns_error(self): + result = _validate_list_input({"a": 1}, 2) + assert result is not None + assert "dict" in result + assert "index 2" in result + + def test_tuple_returns_error(self): + result = _validate_list_input((1, 2), 3) + assert result is not None + assert "tuple" in result + + def test_boolean_returns_error(self): + result = _validate_list_input(True, 0) + assert result is not None + assert "bool" in result + + def test_float_returns_error(self): + result = _validate_list_input(3.14, 0) + assert result is not None + assert "float" in result + + +class TestValidateAllLists: + """Tests for the _validate_all_lists helper.""" + + def test_all_valid_lists(self): + assert _validate_all_lists([[1], [2], [3]]) is None + + def test_empty_outer_list(self): + assert _validate_all_lists([]) is None + + def test_mixed_valid_and_none(self): + # None is skipped, so this should pass + assert _validate_all_lists([[1], None, [3]]) is None + + def test_invalid_item_returns_error(self): + result = _validate_all_lists([[1], "bad", [3]]) + assert result is not None + assert "index 1" in result + + def test_first_invalid_is_returned(self): + result = _validate_all_lists(["first_bad", "second_bad"]) + assert result is not None + assert "index 0" in result + + def test_all_none_passes(self): + assert _validate_all_lists([None, None, None]) is None + + +class TestConcatenateListsSimple: + """Tests for the _concatenate_lists_simple helper.""" + + def test_basic_concatenation(self): + assert _concatenate_lists_simple([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_empty_lists(self): + assert _concatenate_lists_simple([[], []]) == [] + + def test_single_list(self): + 
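+        # a single input list should come back unchanged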
assert _concatenate_lists_simple([[1, 2, 3]]) == [1, 2, 3] + + def test_no_lists(self): + assert _concatenate_lists_simple([]) == [] + + def test_skip_none_values(self): + assert _concatenate_lists_simple([[1, 2], None, [3, 4]]) == [1, 2, 3, 4] # type: ignore[arg-type] + + def test_mixed_types(self): + result = _concatenate_lists_simple([[1, "a"], [True, 3.14]]) + assert result == [1, "a", True, 3.14] + + def test_nested_lists_preserved(self): + result = _concatenate_lists_simple([[[1, 2]], [[3, 4]]]) + assert result == [[1, 2], [3, 4]] + + def test_large_number_of_lists(self): + lists = [[i] for i in range(100)] + result = _concatenate_lists_simple(lists) + assert result == list(range(100)) + + +class TestFlattenNestedList: + """Tests for the _flatten_nested_list helper.""" + + def test_already_flat(self): + assert _flatten_nested_list([1, 2, 3]) == [1, 2, 3] + + def test_one_level_nesting(self): + assert _flatten_nested_list([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_deep_nesting(self): + assert _flatten_nested_list([1, [2, [3, [4, [5]]]]]) == [1, 2, 3, 4, 5] + + def test_empty_list(self): + assert _flatten_nested_list([]) == [] + + def test_mixed_nesting(self): + assert _flatten_nested_list([1, [2, 3], 4, [5, [6]]]) == [1, 2, 3, 4, 5, 6] + + def test_max_depth_zero(self): + # max_depth=0 means no flattening at all + result = _flatten_nested_list([[1, 2], [3, 4]], max_depth=0) + assert result == [[1, 2], [3, 4]] + + def test_max_depth_one(self): + result = _flatten_nested_list([[1, [2, 3]], [4, [5]]], max_depth=1) + assert result == [1, [2, 3], 4, [5]] + + def test_max_depth_two(self): + result = _flatten_nested_list([[[1, 2], [3]], [[4, [5]]]], max_depth=2) + assert result == [1, 2, 3, 4, [5]] + + def test_unlimited_depth(self): + deeply_nested = [[[[[[[1]]]]]]] + assert _flatten_nested_list(deeply_nested, max_depth=-1) == [1] + + def test_preserves_non_list_iterables(self): + result = _flatten_nested_list(["hello", [1, 2]]) + assert result == ["hello", 1, 2] + + def test_preserves_dicts(self): + result = _flatten_nested_list([{"a": 1}, [{"b": 2}]]) + assert result == [{"a": 1}, {"b": 2}] + + def test_excessive_depth_raises_recursion_error(self): + """Deeply nested lists beyond 1000 levels should raise RecursionError.""" + # Build a list nested 1100 levels deep + nested = [42] + for _ in range(1100): + nested = [nested] + with pytest.raises(RecursionError, match="maximum.*depth"): + _flatten_nested_list(nested, max_depth=-1) + + +class TestDeduplicateList: + """Tests for the _deduplicate_list helper.""" + + def test_no_duplicates(self): + assert _deduplicate_list([1, 2, 3]) == [1, 2, 3] + + def test_with_duplicates(self): + assert _deduplicate_list([1, 2, 2, 3, 3, 3]) == [1, 2, 3] + + def test_all_duplicates(self): + assert _deduplicate_list([1, 1, 1]) == [1] + + def test_empty_list(self): + assert _deduplicate_list([]) == [] + + def test_preserves_order(self): + result = _deduplicate_list([3, 1, 2, 1, 3]) + assert result == [3, 1, 2] + + def test_string_duplicates(self): + assert _deduplicate_list(["a", "b", "a", "c"]) == ["a", "b", "c"] + + def test_mixed_types(self): + result = _deduplicate_list([1, "1", 1, "1"]) + assert result == [1, "1"] + + def test_dict_duplicates(self): + result = _deduplicate_list([{"a": 1}, {"a": 1}, {"b": 2}]) + assert result == [{"a": 1}, {"b": 2}] + + def test_list_duplicates(self): + result = _deduplicate_list([[1, 2], [1, 2], [3, 4]]) + assert result == [[1, 2], [3, 4]] + + def test_none_duplicates(self): + result = _deduplicate_list([None, 1, None, 
2]) + assert result == [None, 1, 2] + + def test_single_element(self): + assert _deduplicate_list([42]) == [42] + + +class TestMakeHashable: + """Tests for the _make_hashable helper.""" + + def test_integer(self): + assert _make_hashable(42) == 42 + + def test_string(self): + assert _make_hashable("hello") == "hello" + + def test_none(self): + assert _make_hashable(None) is None + + def test_dict_returns_tuple(self): + result = _make_hashable({"a": 1}) + assert isinstance(result, tuple) + # Should be hashable + hash(result) + + def test_list_returns_tuple(self): + result = _make_hashable([1, 2, 3]) + assert result == (1, 2, 3) + + def test_same_dict_same_hash(self): + assert _make_hashable({"a": 1, "b": 2}) == _make_hashable({"a": 1, "b": 2}) + + def test_different_dict_different_hash(self): + assert _make_hashable({"a": 1}) != _make_hashable({"a": 2}) + + def test_dict_key_order_independent(self): + """Dicts with same keys in different insertion order produce same result.""" + d1 = {"b": 2, "a": 1} + d2 = {"a": 1, "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + def test_tuple_hashable(self): + result = _make_hashable((1, 2, 3)) + assert result == (1, 2, 3) + hash(result) + + def test_boolean(self): + result = _make_hashable(True) + assert result is True + + def test_float(self): + result = _make_hashable(3.14) + assert result == 3.14 + + +class TestFilterNoneValues: + """Tests for the _filter_none_values helper.""" + + def test_removes_none(self): + assert _filter_none_values([1, None, 2, None, 3]) == [1, 2, 3] + + def test_no_none(self): + assert _filter_none_values([1, 2, 3]) == [1, 2, 3] + + def test_all_none(self): + assert _filter_none_values([None, None, None]) == [] + + def test_empty_list(self): + assert _filter_none_values([]) == [] + + def test_preserves_falsy_values(self): + assert _filter_none_values([0, False, "", None, []]) == [0, False, "", []] + + +class TestComputeNestingDepth: + """Tests for the _compute_nesting_depth helper.""" + + def test_flat_list(self): + assert _compute_nesting_depth([1, 2, 3]) == 1 + + def test_one_level(self): + assert _compute_nesting_depth([[1, 2], [3, 4]]) == 2 + + def test_deep_nesting(self): + assert _compute_nesting_depth([[[[]]]]) == 4 + + def test_mixed_depth(self): + depth = _compute_nesting_depth([1, [2, [3]]]) + assert depth == 3 + + def test_empty_list(self): + assert _compute_nesting_depth([]) == 1 + + def test_non_list(self): + assert _compute_nesting_depth(42) == 0 + + def test_string_not_recursed(self): + # Strings should not be treated as nested lists + assert _compute_nesting_depth(["hello"]) == 1 + + +class TestInterleaveListsHelper: + """Tests for the _interleave_lists helper.""" + + def test_equal_length_lists(self): + result = _interleave_lists([[1, 2, 3], ["a", "b", "c"]]) + assert result == [1, "a", 2, "b", 3, "c"] + + def test_unequal_length_lists(self): + result = _interleave_lists([[1, 2, 3], ["a"]]) + assert result == [1, "a", 2, 3] + + def test_empty_input(self): + assert _interleave_lists([]) == [] + + def test_single_list(self): + assert _interleave_lists([[1, 2, 3]]) == [1, 2, 3] + + def test_three_lists(self): + result = _interleave_lists([[1], [2], [3]]) + assert result == [1, 2, 3] + + def test_with_none_list(self): + result = _interleave_lists([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [1, 3, 2, 4] + + def test_all_empty_lists(self): + assert _interleave_lists([[], [], []]) == [] + + def test_all_none_lists(self): + """All-None inputs should return empty list, not 
crash.""" + assert _interleave_lists([None, None, None]) == [] # type: ignore[arg-type] + + +class TestComputeNestingDepthEdgeCases: + """Tests for _compute_nesting_depth with deeply nested input.""" + + def test_deeply_nested_does_not_crash(self): + """Deeply nested lists beyond 1000 levels should not raise RecursionError.""" + nested = [42] + for _ in range(1100): + nested = [nested] + # Should return a depth value without crashing + depth = _compute_nesting_depth(nested) + assert depth >= _MAX_FLATTEN_DEPTH + + +class TestMakeHashableMixedKeys: + """Tests for _make_hashable with mixed-type dict keys.""" + + def test_mixed_type_dict_keys(self): + """Dicts with mixed-type keys (int and str) should not crash sorted().""" + d = {1: "one", "two": 2} + result = _make_hashable(d) + assert isinstance(result, tuple) + hash(result) # Should be hashable without error + + def test_mixed_type_keys_deterministic(self): + """Same dict with mixed keys produces same result.""" + d1 = {1: "a", "b": 2} + d2 = {1: "a", "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + +class TestZipListsNoneHandling: + """Tests for ZipListsBlock with None values in input.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate_with_none(self): + """_zip_truncate should handle None values in input lists.""" + result = self.block._zip_truncate([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [[1, 3], [2, 4]] + + def test_zip_pad_with_none(self): + """_zip_pad should handle None values in input lists.""" + result = self.block._zip_pad([[1, 2, 3], None, ["a"]], fill_value="X") # type: ignore[arg-type] + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_truncate_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_truncate([None, None]) # type: ignore[arg-type] + assert result == [] + + def test_zip_pad_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_pad([None, None], fill_value=0) # type: ignore[arg-type] + assert result == [] + + +# ============================================================================= +# Block Built-in Tests (using test_input/test_output) +# ============================================================================= + + +class TestConcatenateListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ConcatenateListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ConcatenateListsBlock() + await execute_block_test(block) + + +class TestFlattenListBlockBuiltin: + """Run the built-in test_input/test_output tests for FlattenListBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = FlattenListBlock() + await execute_block_test(block) + + +class TestInterleaveListsBlockBuiltin: + """Run the built-in test_input/test_output tests for InterleaveListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = InterleaveListsBlock() + await execute_block_test(block) + + +class TestZipListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ZipListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ZipListsBlock() + await execute_block_test(block) + + +class TestListDifferenceBlockBuiltin: + """Run the built-in test_input/test_output tests for ListDifferenceBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListDifferenceBlock() + await execute_block_test(block) + + +class 
TestListIntersectionBlockBuiltin: + """Run the built-in test_input/test_output tests for ListIntersectionBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListIntersectionBlock() + await execute_block_test(block) + + +# ============================================================================= +# ConcatenateListsBlock Manual Tests +# ============================================================================= + + +class TestConcatenateListsBlockManual: + """Manual test cases for ConcatenateListsBlock edge cases.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + @pytest.mark.asyncio + async def test_two_lists(self): + """Test basic two-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_three_lists(self): + """Test three-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_five_lists(self): + """Test concatenation of five lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3], [4], [5]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + assert results["length"] == 5 + + @pytest.mark.asyncio + async def test_empty_lists_only(self): + """Test concatenation of only empty lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[], [], []]) + ): + results[name] = value + assert results["concatenated_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_mixed_types_in_lists(self): + """Test concatenation with mixed types.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, "a"], [True, 3.14], [None, {"key": "val"}]] + ) + ): + results[name] = value + assert results["concatenated_list"] == [ + 1, + "a", + True, + 3.14, + None, + {"key": "val"}, + ] + + @pytest.mark.asyncio + async def test_deduplication_enabled(self): + """Test deduplication removes duplicates.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, 2, 3], [2, 3, 4], [3, 4, 5]], + deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_deduplication_preserves_order(self): + """Test that deduplication preserves first-occurrence order.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[3, 1, 2], [2, 4, 1]], + deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [3, 1, 2, 4] + + @pytest.mark.asyncio + async def test_remove_none_enabled(self): + """Test None removal from concatenated results.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None], [None, 2], [3, None]], + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_dedup_and_remove_none_combined(self): + """Test both deduplication and None removal together.""" + results = {} + 
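+        # concatenation yields [1, None, 2, 2, None, 3]; deduplication keeps
+        # first occurrences -> [1, None, 2, 3]; removing None -> [1, 2, 3]
+        # (either ordering of the two post-steps gives the same result here)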
async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None, 2], [2, None, 3]], + deduplicate=True, + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_nested_lists_preserved(self): + """Test that nested lists are not flattened during concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[[1, 2]], [[3, 4]]]) + ): + results[name] = value + assert results["concatenated_list"] == [[1, 2], [3, 4]] + + @pytest.mark.asyncio + async def test_large_lists(self): + """Test concatenation of large lists.""" + list_a = list(range(1000)) + list_b = list(range(1000, 2000)) + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[list_a, list_b]) + ): + results[name] = value + assert results["concatenated_list"] == list(range(2000)) + assert results["length"] == 2000 + + @pytest.mark.asyncio + async def test_single_list_input(self): + """Test concatenation with a single list.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2, 3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + @pytest.mark.asyncio + async def test_block_category(self): + """Test that the block has the correct category.""" + from backend.blocks._base import BlockCategory + + assert BlockCategory.BASIC in self.block.categories + + +# ============================================================================= +# FlattenListBlock Manual Tests +# ============================================================================= + + +class TestFlattenListBlockManual: + """Manual test cases for FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + @pytest.mark.asyncio + async def test_simple_flatten(self): + """Test flattening a simple nested list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_deeply_nested(self): + """Test flattening a deeply nested structure.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3, [4, [5]]]]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_partial_flatten(self): + """Test flattening with max_depth=1.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input( + nested_list=[[1, [2, 3]], [4, [5]]], + max_depth=1, + ) + ): + results[name] = value + assert results["flattened_list"] == [1, [2, 3], 4, [5]] + + @pytest.mark.asyncio + async def test_already_flat_list(self): + """Test flattening an already flat list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, 2, 3, 4]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_empty_nested_lists(self): + """Test flattening with empty nested lists.""" + results = {} + async for name, value in self.block.run( + 
FlattenListBlock.Input(nested_list=[[], [1], [], [2], []]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2] + + @pytest.mark.asyncio + async def test_mixed_types_preserved(self): + """Test that non-list types are preserved during flattening.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=["hello", [1, {"a": 1}], [True]]) + ): + results[name] = value + assert results["flattened_list"] == ["hello", 1, {"a": 1}, True] + + @pytest.mark.asyncio + async def test_original_depth_reported(self): + """Test that original nesting depth is correctly reported.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3]]]) + ): + results[name] = value + assert results["original_depth"] == 3 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# InterleaveListsBlock Manual Tests +# ============================================================================= + + +class TestInterleaveListsBlockManual: + """Manual test cases for InterleaveListsBlock.""" + + def setup_method(self): + self.block = InterleaveListsBlock() + + @pytest.mark.asyncio + async def test_equal_length_interleave(self): + """Test interleaving two equal-length lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, "c"] + + @pytest.mark.asyncio + async def test_unequal_length_interleave(self): + """Test interleaving lists of different lengths.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3, 4], ["a", "b"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, 4] + + @pytest.mark.asyncio + async def test_three_lists_interleave(self): + """Test interleaving three lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2], ["a", "b"], ["x", "y"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", "x", 2, "b", "y"] + + @pytest.mark.asyncio + async def test_single_element_lists(self): + """Test interleaving single-element lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1], [2], [3], [4]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ZipListsBlock Manual Tests +# ============================================================================= + + +class TestZipListsBlockManual: + """Manual test cases for ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + @pytest.mark.asyncio + async def test_basic_zip(self): + """Test basic zipping of two lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"], [3, 
"c"]] + + @pytest.mark.asyncio + async def test_truncate_to_shortest(self): + """Test that default behavior truncates to shortest list.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"]] + assert results["length"] == 2 + + @pytest.mark.asyncio + async def test_pad_to_longest(self): + """Test padding shorter lists with fill value.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2, 3], ["a"]], + pad_to_longest=True, + fill_value="X", + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "X"], [3, "X"]] + + @pytest.mark.asyncio + async def test_pad_with_none(self): + """Test padding with None (default fill value).""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2], ["a"]], + pad_to_longest=True, + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, None]] + + @pytest.mark.asyncio + async def test_three_lists_zip(self): + """Test zipping three lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2], ["a", "b"], [True, False]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a", True], [2, "b", False]] + + @pytest.mark.asyncio + async def test_empty_lists_zip(self): + """Test zipping empty input.""" + results = {} + async for name, value in self.block.run(ZipListsBlock.Input(lists=[])): + results[name] = value + assert results["zipped_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListDifferenceBlock Manual Tests +# ============================================================================= + + +class TestListDifferenceBlockManual: + """Manual test cases for ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + @pytest.mark.asyncio + async def test_basic_difference(self): + """Test basic set difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["difference"] == [1, 2] + + @pytest.mark.asyncio + async def test_symmetric_difference(self): + """Test symmetric difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[2, 3, 4], + symmetric=True, + ) + ): + results[name] = value + assert results["difference"] == [1, 4] + + @pytest.mark.asyncio + async def test_no_difference(self): + """Test when lists are identical.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["difference"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_complete_difference(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def 
test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["difference"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_string_difference(self): + """Test difference with string elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "date"], + ) + ): + results[name] = value + assert results["difference"] == ["apple", "cherry"] + + @pytest.mark.asyncio + async def test_dict_difference(self): + """Test difference with dictionary elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[{"a": 1}, {"b": 2}, {"c": 3}], + list_b=[{"b": 2}], + ) + ): + results[name] = value + assert results["difference"] == [{"a": 1}, {"c": 3}] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListIntersectionBlock Manual Tests +# ============================================================================= + + +class TestListIntersectionBlockManual: + """Manual test cases for ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + @pytest.mark.asyncio + async def test_basic_intersection(self): + """Test basic intersection.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["intersection"] == [3, 4, 5] + assert results["length"] == 3 + + @pytest.mark.asyncio + async def test_no_intersection(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["intersection"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_identical_lists(self): + """Test intersection of identical lists.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_preserves_order_from_list_a(self): + """Test that intersection preserves order from list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[5, 3, 1], + list_b=[1, 3, 5], + ) + ): + results[name] = value + assert results["intersection"] == [5, 3, 1] + + @pytest.mark.asyncio + async def test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async 
for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_string_intersection(self): + """Test intersection with string elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "cherry", "date"], + ) + ): + results[name] = value + assert results["intersection"] == ["banana", "cherry"] + + @pytest.mark.asyncio + async def test_deduplication_in_intersection(self): + """Test that duplicates in input don't cause duplicate results.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 1, 2, 2, 3], + list_b=[1, 2], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# Block Method Tests +# ============================================================================= + + +class TestConcatenateListsBlockMethods: + """Tests for internal methods of ConcatenateListsBlock.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + def test_validate_inputs_valid(self): + assert self.block._validate_inputs([[1], [2]]) is None + + def test_validate_inputs_invalid(self): + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + def test_perform_concatenation(self): + result = self.block._perform_concatenation([[1, 2], [3, 4]]) + assert result == [1, 2, 3, 4] + + def test_apply_deduplication(self): + result = self.block._apply_deduplication([1, 2, 2, 3]) + assert result == [1, 2, 3] + + def test_apply_none_removal(self): + result = self.block._apply_none_removal([1, None, 2]) + assert result == [1, 2] + + def test_post_process_all_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=True, remove_none=True + ) + assert result == [1, 2] + + def test_post_process_no_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=False, remove_none=False + ) + assert result == [1, None, 2, None, 2] + + +class TestFlattenListBlockMethods: + """Tests for internal methods of FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + def test_compute_depth_flat(self): + assert self.block._compute_depth([1, 2, 3]) == 1 + + def test_compute_depth_nested(self): + assert self.block._compute_depth([[1, [2]]]) == 3 + + def test_flatten_unlimited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=-1) + assert result == [1, 2, 3] + + def test_flatten_limited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=1) + assert result == [1, 2, [3]] + + def test_validate_max_depth_valid(self): + assert self.block._validate_max_depth(-1) is None + assert self.block._validate_max_depth(0) is None + assert self.block._validate_max_depth(5) is None + + def test_validate_max_depth_invalid(self): + result = self.block._validate_max_depth(-2) + assert result is not None + + +class TestZipListsBlockMethods: + """Tests for internal methods of ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate(self): + result = 
self.block._zip_truncate([[1, 2, 3], ["a", "b"]]) + assert result == [[1, "a"], [2, "b"]] + + def test_zip_pad(self): + result = self.block._zip_pad([[1, 2, 3], ["a"]], fill_value="X") + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_pad_empty(self): + result = self.block._zip_pad([], fill_value=None) + assert result == [] + + def test_validate_inputs(self): + assert self.block._validate_inputs([[1], [2]]) is None + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + +class TestListDifferenceBlockMethods: + """Tests for internal methods of ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + def test_compute_difference(self): + result = self.block._compute_difference([1, 2, 3], [2, 3, 4]) + assert result == [1] + + def test_compute_symmetric_difference(self): + result = self.block._compute_symmetric_difference([1, 2, 3], [2, 3, 4]) + assert result == [1, 4] + + def test_compute_difference_empty(self): + result = self.block._compute_difference([], [1, 2]) + assert result == [] + + def test_compute_symmetric_difference_identical(self): + result = self.block._compute_symmetric_difference([1, 2], [1, 2]) + assert result == [] + + +class TestListIntersectionBlockMethods: + """Tests for internal methods of ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + def test_compute_intersection(self): + result = self.block._compute_intersection([1, 2, 3], [2, 3, 4]) + assert result == [2, 3] + + def test_compute_intersection_empty(self): + result = self.block._compute_intersection([], [1, 2]) + assert result == [] + + def test_compute_intersection_no_overlap(self): + result = self.block._compute_intersection([1, 2], [3, 4]) + assert result == [] diff --git a/autogpt_platform/backend/test/chat/__init__.py b/autogpt_platform/backend/test/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/test/chat/test_security_hooks.py b/autogpt_platform/backend/test/chat/test_security_hooks.py new file mode 100644 index 0000000000..f10a90871b --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_security_hooks.py @@ -0,0 +1,133 @@ +"""Tests for SDK security hooks — workspace paths, tool access, and deny messages. + +These are pure unit tests with no external dependencies (no SDK, no DB, no server). +They validate that the security hooks correctly block unauthorized paths, +tool access, and dangerous input patterns. + +Note: Bash command validation was removed — the SDK built-in Bash tool is not in +allowed_tools, and the bash_exec MCP tool has kernel-level network isolation +(unshare --net) making command-level parsing unnecessary. +""" + +from backend.api.features.chat.sdk.security_hooks import ( + _validate_tool_access, + _validate_workspace_path, +) + +SDK_CWD = "/tmp/copilot-test-session" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +def _reason(result: dict) -> str: + return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") + + +# ============================================================ +# Workspace path validation (Read, Write, Edit, etc.) 
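+#
+# A denial surfaces as a dict of the shape the helpers above read (a sketch
+# inferred from _is_denied/_reason, not the hooks' literal return value):
+#
+#   {"hookSpecificOutput": {"permissionDecision": "deny",
+#                           "permissionDecisionReason": "[SECURITY] ..."}}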
+# ============================================================ + + +class TestWorkspacePathValidation: + def test_path_in_workspace(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_path_outside_workspace(self): + result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + assert _is_denied(result) + + def test_tool_results_allowed(self): + result = _validate_workspace_path( + "Read", + {"file_path": "~/.claude/projects/abc/tool-results/out.txt"}, + SDK_CWD, + ) + assert not _is_denied(result) + + def test_claude_settings_blocked(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_claude_projects_without_tool_results(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_no_path_allowed(self): + """Glob/Grep without path defaults to cwd — should be allowed.""" + result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD) + assert not _is_denied(result) + + def test_path_traversal_with_dotdot(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD + ) + assert _is_denied(result) + + +# ============================================================ +# Tool access validation +# ============================================================ + + +class TestToolAccessValidation: + def test_blocked_tools(self): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"Tool '{tool}' should be blocked" + + def test_bash_builtin_blocked(self): + """SDK built-in Bash (capital) is blocked as defence-in-depth.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD) + assert _is_denied(result) + assert "Bash" in _reason(result) + + def test_workspace_tools_delegate(self): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_dangerous_pattern_blocked(self): + result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"}) + assert _is_denied(result) + + def test_safe_unknown_tool_allowed(self): + result = _validate_tool_access("SomeSafeTool", {"data": "hello world"}) + assert not _is_denied(result) + + +# ============================================================ +# Deny message quality (ntindle feedback) +# ============================================================ + + +class TestDenyMessageClarity: + """Deny messages must include [SECURITY] and 'cannot be bypassed' + so the model knows the restriction is enforced, not a suggestion.""" + + def test_blocked_tool_message(self): + reason = _reason(_validate_tool_access("bash", {})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + def test_bash_builtin_blocked_message(self): + reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + def test_workspace_path_message(self): + reason = _reason( + _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + ) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason diff --git a/autogpt_platform/backend/test/chat/test_transcript.py 
b/autogpt_platform/backend/test/chat/test_transcript.py new file mode 100644 index 0000000000..71b1fad81f --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_transcript.py @@ -0,0 +1,255 @@ +"""Unit tests for JSONL transcript management utilities.""" + +import json +import os + +from backend.api.features.chat.sdk.transcript import ( + STRIPPABLE_TYPES, + read_transcript_file, + strip_progress_entries, + validate_transcript, + write_transcript_to_tempfile, +) + + +def _make_jsonl(*entries: dict) -> str: + return "\n".join(json.dumps(e) for e in entries) + "\n" + + +# --- Fixtures --- + + +METADATA_LINE = {"type": "queue-operation", "subtype": "create"} +FILE_HISTORY = {"type": "file-history-snapshot", "files": []} +USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}} +ASST_MSG = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": {"role": "assistant", "content": "hello"}, +} +PROGRESS_ENTRY = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress", "stdout": "running..."}, +} + +VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG) + + +# --- read_transcript_file --- + + +class TestReadTranscriptFile: + def test_returns_content_for_valid_file(self, tmp_path): + path = tmp_path / "session.jsonl" + path.write_text(VALID_TRANSCRIPT) + result = read_transcript_file(str(path)) + assert result is not None + assert "user" in result + + def test_returns_none_for_missing_file(self): + assert read_transcript_file("/nonexistent/path.jsonl") is None + + def test_returns_none_for_empty_path(self): + assert read_transcript_file("") is None + + def test_returns_none_for_empty_file(self, tmp_path): + path = tmp_path / "empty.jsonl" + path.write_text("") + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_metadata_only(self, tmp_path): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + path = tmp_path / "meta.jsonl" + path.write_text(content) + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_invalid_json(self, tmp_path): + path = tmp_path / "bad.jsonl" + path.write_text("not json\n{}\n{}\n") + assert read_transcript_file(str(path)) is None + + def test_no_size_limit(self, tmp_path): + """Large files are accepted — bucket storage has no size limit.""" + big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000} + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG) + path = tmp_path / "big.jsonl" + path.write_text(content) + result = read_transcript_file(str(path)) + assert result is not None + + +# --- write_transcript_to_tempfile --- + + +class TestWriteTranscriptToTempfile: + """Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check.""" + + def test_writes_file_and_returns_path(self): + cwd = "/tmp/copilot-test-write" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "sess-1234-abcd", cwd + ) + assert result is not None + assert os.path.isfile(result) + assert result.endswith(".jsonl") + with open(result) as f: + assert f.read() == VALID_TRANSCRIPT + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_creates_parent_directory(self): + cwd = "/tmp/copilot-test-mkdir" + try: + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is not None + assert os.path.isdir(cwd) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_uses_session_id_prefix(self): + cwd = 
"/tmp/copilot-test-prefix" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "abcdef12-rest", cwd + ) + assert result is not None + assert "abcdef12" in os.path.basename(result) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_rejects_cwd_outside_sandbox(self, tmp_path): + cwd = str(tmp_path / "not-copilot") + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is None + + +# --- validate_transcript --- + + +class TestValidateTranscript: + def test_valid_transcript(self): + assert validate_transcript(VALID_TRANSCRIPT) is True + + def test_none_content(self): + assert validate_transcript(None) is False + + def test_empty_content(self): + assert validate_transcript("") is False + + def test_metadata_only(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + assert validate_transcript(content) is False + + def test_user_only_no_assistant(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG) + assert validate_transcript(content) is False + + def test_assistant_only_no_user(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG) + assert validate_transcript(content) is False + + def test_invalid_json_returns_false(self): + assert validate_transcript("not json\n{}\n{}\n") is False + + +# --- strip_progress_entries --- + + +class TestStripProgressEntries: + def test_strips_all_strippable_types(self): + """All STRIPPABLE_TYPES are removed from the output.""" + entries = [ + USER_MSG, + {"type": "progress", "uuid": "p1", "parentUuid": "u1"}, + {"type": "file-history-snapshot", "files": []}, + {"type": "queue-operation", "subtype": "create"}, + {"type": "summary", "text": "..."}, + {"type": "pr-link", "url": "..."}, + ASST_MSG, + ] + result = strip_progress_entries(_make_jsonl(*entries)) + result_types = {json.loads(line)["type"] for line in result.strip().split("\n")} + assert result_types == {"user", "assistant"} + for stype in STRIPPABLE_TYPES: + assert stype not in result_types + + def test_reparents_children_of_stripped_entries(self): + """An assistant message whose parent is a progress entry gets reparented.""" + progress = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress"}, + } + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p1", # Points to progress + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, progress, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + # Should be reparented to u1 (the user message) + assert asst_entry["parentUuid"] == "u1" + + def test_reparents_through_chain(self): + """Reparenting walks through multiple stripped entries.""" + p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"} + p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"} + p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"} + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p3", # 3 levels deep + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, p1, p2, p3, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + assert asst_entry["parentUuid"] == "u1" + + def test_preserves_non_strippable_entries(self): + """User, assistant, and 
system entries are preserved."""
+        system = {"type": "system", "uuid": "s1", "message": "prompt"}
+        content = _make_jsonl(system, USER_MSG, ASST_MSG)
+        result = strip_progress_entries(content)
+        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
+        assert result_types == ["system", "user", "assistant"]
+
+    def test_empty_input(self):
+        result = strip_progress_entries("")
+        # Should return just a newline (empty content stripped)
+        assert result.strip() == ""
+
+    def test_no_strippable_entries(self):
+        """When there's nothing to strip, output matches input structure."""
+        content = _make_jsonl(USER_MSG, ASST_MSG)
+        result = strip_progress_entries(content)
+        result_lines = result.strip().split("\n")
+        assert len(result_lines) == 2
+
+    def test_handles_entries_without_uuid(self):
+        """Entries without a uuid field are handled gracefully."""
+        no_uuid = {"type": "queue-operation", "subtype": "create"}
+        content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
+        result = strip_progress_entries(content)
+        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
+        # queue-operation is strippable
+        assert "queue-operation" not in result_types
+        assert "user" in result_types
+        assert "assistant" in result_types
diff --git a/autogpt_platform/docker-compose.platform.yml b/autogpt_platform/docker-compose.platform.yml
index de6ecfd612..bab92d4693 100644
--- a/autogpt_platform/docker-compose.platform.yml
+++ b/autogpt_platform/docker-compose.platform.yml
@@ -37,7 +37,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: migrate
-    command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
+    command: ["sh", "-c", "prisma generate && python3 gen_prisma_types_stub.py && prisma migrate deploy"]
     develop:
       watch:
         - path: ./
@@ -56,7 +56,7 @@ services:
       test:
         [
           "CMD-SHELL",
-          "poetry run prisma migrate status | grep -q 'No pending migrations' || exit 1",
+          "prisma migrate status | grep -q 'No pending migrations' || exit 1",
        ]
      interval: 30s
      timeout: 10s
diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts
new file mode 100644
index 0000000000..326f42e049
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts
@@ -0,0 +1,96 @@
+import { NextResponse } from "next/server";
+
+/**
+ * Safely encode a value as JSON for embedding in a script tag.
+ * Escapes characters that could break out of the script context to prevent XSS.
+ */
+function safeJsonStringify(value: unknown): string {
+  return JSON.stringify(value)
+    .replace(/</g, "\\u003c")
+    .replace(/>/g, "\\u003e")
+    .replace(/&/g, "\\u0026");
+}
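+
+// Illustrative example (values hypothetical, not from this PR):
+//   safeJsonStringify({ html: "</script>" })
+//   -> '{"html":"\u003c/script\u003e"}'
+// so the payload cannot terminate the script tag it is inlined into.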
+
+// MCP-specific OAuth callback route.
+//
+// Unlike the generic oauth_callback which relies on window.opener.postMessage,
+// this route uses BroadcastChannel as the PRIMARY communication method.
+// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
+// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
+//
+// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
+export async function GET(request: Request) {
+  const { searchParams } = new URL(request.url);
+  const code = searchParams.get("code");
+  const state = searchParams.get("state");
+
+  const success = Boolean(code && state);
+  const message = success
+    ? { success: true, code, state }
+    : {
+        success: false,
+        message: `Missing parameters: ${searchParams.toString()}`,
+      };
+
+  return new NextResponse(
+    `<!DOCTYPE html>
+<html>
+  <head>
+    <title>MCP Sign-in</title>
+  </head>
+  <body>
+    <div>Completing sign-in...</div>
+    <script>
+      /* Inline script stripped in extraction. Per the comments above, it
+         broadcast the safeJsonStringify(message) payload over a
+         BroadcastChannel, with window.opener.postMessage as fallback,
+         then closed the popup. */
+    </script>
+  </body>
+</html>`,
+    { headers: { "Content-Type": "text/html" } },
+  );
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx
index d4aa26480d..62e796b748 100644
--- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx
@@ -47,7 +47,10 @@ export type CustomNode = XYNode<CustomNodeData>;
 export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
   ({ data, id: nodeId, selected }) => {
-    const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
+    const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
+      data,
+      nodeId,
+    });
 
     const isAgent = data.uiType === BlockUIType.AGENT;
@@ -98,6 +101,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
           jsonSchema={preprocessInputSchema(inputSchema)}
           nodeId={nodeId}
           uiType={data.uiType}
+          isMCPWithTool={isMCPWithTool}
           className={cn(
             "bg-white px-4",
             isWebhook && "pointer-events-none opacity-50",
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx
index c4659b8dcf..9a3add62b6 100644
--- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx
@@ -20,10 +20,8 @@ type Props = {
 export const NodeHeader = ({ data, nodeId }: Props) => {
   const updateNodeData = useNodeStore((state) => state.updateNodeData);
-  const title =
-    (data.metadata?.customized_name as string) ||
-    data.hardcodedValues?.agent_name ||
-    data.title;
+
+  const title = (data.metadata?.customized_name as string) || data.title;
 
   const [isEditingTitle, setIsEditingTitle] = useState(false);
   const [editedTitle, setEditedTitle] = useState(title);
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx
index e58d0ab12b..050515a02f 100644
--- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx
@@ -3,6 +3,34 @@ import { CustomNodeData } from "./CustomNode";
 import { BlockUIType } from "../../../types";
 import { useMemo } from "react";
 import { mergeSchemaForResolution } from "./helpers";
+/**
+ * Build a dynamic input schema for MCP blocks.
+ *
+ * When a tool has been selected (tool_input_schema is populated), the block
+ * renders the selected tool's input parameters *plus* the credentials field
+ * so users can select/change the OAuth credential used for execution.
+ *
+ * Static fields like server_url, selected_tool, available_tools, and
+ * tool_arguments are hidden because they're pre-configured from the dialog.
+ */
+function buildMCPInputSchema(
+  toolInputSchema: Record<string, any>,
+  blockInputSchema: Record<string, any>,
+): Record<string, any> {
+  // Extract the credentials field from the block's original input schema
+  const credentialsSchema =
+    blockInputSchema?.properties?.credentials ?? undefined;
+
+  return {
+    type: "object",
+    properties: {
+      // Credentials field first so the dropdown appears at the top
+      ...(credentialsSchema ? { credentials: credentialsSchema } : {}),
+      ...(toolInputSchema.properties ?? {}),
+    },
+    required: [...(toolInputSchema.required ?? [])],
+  };
+}
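+// Shape sketch (tool schema hypothetical, not taken from this PR): a
+// tool_input_schema of
+//   { properties: { city: { type: "string" } }, required: ["city"] }
+// merged with a block schema that has a credentials property yields
+//   { type: "object",
+//     properties: { credentials: <credentials schema>, city: { type: "string" } },
+//     required: ["city"] }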
undefined; + + return { + type: "object", + properties: { + // Credentials field first so the dropdown appears at the top + ...(credentialsSchema ? { credentials: credentialsSchema } : {}), + ...(toolInputSchema.properties ?? {}), + }, + required: [...(toolInputSchema.required ?? [])], + }; +} export const useCustomNode = ({ data, @@ -19,10 +47,18 @@ export const useCustomNode = ({ ); const isAgent = data.uiType === BlockUIType.AGENT; + const isMCPWithTool = + data.uiType === BlockUIType.MCP_TOOL && + !!data.hardcodedValues?.tool_input_schema?.properties; const currentInputSchema = isAgent ? (data.hardcodedValues.input_schema ?? {}) - : data.inputSchema; + : isMCPWithTool + ? buildMCPInputSchema( + data.hardcodedValues.tool_input_schema, + data.inputSchema, + ) + : data.inputSchema; const currentOutputSchema = isAgent ? (data.hardcodedValues.output_schema ?? {}) : data.outputSchema; @@ -54,5 +90,6 @@ export const useCustomNode = ({ return { inputSchema, outputSchema, + isMCPWithTool, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx index d6a3fabffa..77b21dda92 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx @@ -9,39 +9,72 @@ interface FormCreatorProps { jsonSchema: RJSFSchema; nodeId: string; uiType: BlockUIType; + /** When true the block is an MCP Tool with a selected tool. */ + isMCPWithTool?: boolean; showHandles?: boolean; className?: string; } export const FormCreator: React.FC = React.memo( - ({ jsonSchema, nodeId, uiType, showHandles = true, className }) => { + ({ + jsonSchema, + nodeId, + uiType, + isMCPWithTool = false, + showHandles = true, + className, + }) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const getHardCodedValues = useNodeStore( (state) => state.getHardCodedValues, ); + const isAgent = uiType === BlockUIType.AGENT; + const handleChange = ({ formData }: any) => { if ("credentials" in formData && !formData.credentials?.id) { delete formData.credentials; } - const updatedValues = - uiType === BlockUIType.AGENT - ? { - ...getHardCodedValues(nodeId), - inputs: formData, - } - : formData; + let updatedValues; + if (isAgent) { + updatedValues = { + ...getHardCodedValues(nodeId), + inputs: formData, + }; + } else if (isMCPWithTool) { + // Separate credentials from tool arguments — credentials are stored + // at the top level of hardcodedValues, not inside tool_arguments. + const { credentials, ...toolArgs } = formData; + updatedValues = { + ...getHardCodedValues(nodeId), + tool_arguments: toolArgs, + ...(credentials?.id ? { credentials } : {}), + }; + } else { + updatedValues = formData; + } updateNodeData(nodeId, { hardcodedValues: updatedValues }); }; const hardcodedValues = getHardCodedValues(nodeId); - const initialValues = - uiType === BlockUIType.AGENT - ? (hardcodedValues.inputs ?? {}) - : hardcodedValues; + + let initialValues; + if (isAgent) { + initialValues = hardcodedValues.inputs ?? {}; + } else if (isMCPWithTool) { + // Merge tool arguments with credentials for the form + initialValues = { + ...(hardcodedValues.tool_arguments ?? {}), + ...(hardcodedValues.credentials?.id + ? { credentials: hardcodedValues.credentials } + : {}), + }; + } else { + initialValues = hardcodedValues; + } return (
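To make the schema merge above concrete: a hypothetical tool schema (tool and field names invented for the example) and the shape buildMCPInputSchema produces for it.

// Schema as returned by MCP tool discovery for a made-up "get_weather" tool.
const toolInputSchema = {
  type: "object",
  properties: { location: { type: "string", description: "City name" } },
  required: ["location"],
};

// The block's own input schema; only its credentials field survives the merge.
const blockInputSchema = {
  properties: { credentials: { title: "Credentials", type: "object" } },
};

const merged = buildMCPInputSchema(toolInputSchema, blockInputSchema);
// merged.properties lists credentials first, then location; the static fields
// (server_url, selected_tool, available_tools, tool_arguments) are absent.
// merged.required === ["location"]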
; + availableTools: Record; + /** Credentials meta from OAuth flow, null for public servers. */ + credentials: CredentialsMetaInput | null; +}; + +interface MCPToolDialogProps { + open: boolean; + onClose: () => void; + onConfirm: (result: MCPToolDialogResult) => void; +} + +type DialogStep = "url" | "tool"; + +export function MCPToolDialog({ + open, + onClose, + onConfirm, +}: MCPToolDialogProps) { + const allProviders = useContext(CredentialsProvidersContext); + + const [step, setStep] = useState("url"); + const [serverUrl, setServerUrl] = useState(""); + const [tools, setTools] = useState([]); + const [serverName, setServerName] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [authRequired, setAuthRequired] = useState(false); + const [oauthLoading, setOauthLoading] = useState(false); + const [showManualToken, setShowManualToken] = useState(false); + const [manualToken, setManualToken] = useState(""); + const [selectedTool, setSelectedTool] = useState( + null, + ); + const [credentials, setCredentials] = useState( + null, + ); + + const startOAuthRef = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); + + const reset = useCallback(() => { + oauthAbortRef.current?.(); + oauthAbortRef.current = null; + setStep("url"); + setServerUrl(""); + setManualToken(""); + setTools([]); + setServerName(null); + setLoading(false); + setError(null); + setAuthRequired(false); + setOauthLoading(false); + setShowManualToken(false); + setSelectedTool(null); + setCredentials(null); + }, []); + + const handleClose = useCallback(() => { + reset(); + onClose(); + }, [reset, onClose]); + + const discoverTools = useCallback(async (url: string, authToken?: string) => { + setLoading(true); + setError(null); + try { + const response = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: url, + auth_token: authToken || null, + }); + if (response.status !== 200) throw response.data; + setTools(response.data.tools); + setServerName(response.data.server_name ?? null); + setAuthRequired(false); + setShowManualToken(false); + setStep("tool"); + } catch (e: any) { + if (e?.status === 401 || e?.status === 403) { + setAuthRequired(true); + setError(null); + // Automatically start OAuth sign-in instead of requiring a second click + setLoading(false); + startOAuthRef.current = true; + return; + } else { + const message = + e?.message || e?.detail || "Failed to connect to MCP server"; + setError( + typeof message === "string" ? 
message : JSON.stringify(message), + ); + } + } finally { + setLoading(false); + } + }, []); + + const handleDiscoverTools = useCallback(() => { + if (!serverUrl.trim()) return; + discoverTools(serverUrl.trim(), manualToken.trim() || undefined); + }, [serverUrl, manualToken, discoverTools]); + + const handleOAuthSignIn = useCallback(async () => { + if (!serverUrl.trim()) return; + setError(null); + + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + setOauthLoading(true); + + try { + const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (loginResponse.status !== 200) throw loginResponse.data; + const { login_url, state_token } = loginResponse.data; + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: true, + }); + oauthAbortRef.current = cleanup.abort; + + const result = await promise; + + // Exchange code for tokens via the credentials provider (updates cache) + setLoading(true); + setOauthLoading(false); + + const mcpProvider = allProviders?.["mcp"]; + let callbackResult; + if (mcpProvider) { + callbackResult = await mcpProvider.mcpOAuthCallback( + result.code, + state_token, + ); + } else { + const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({ + code: result.code, + state_token, + }); + if (cbResponse.status !== 200) throw cbResponse.data; + callbackResult = cbResponse.data; + } + + setCredentials({ + id: callbackResult.id, + provider: callbackResult.provider, + type: callbackResult.type, + title: callbackResult.title, + }); + setAuthRequired(false); + + // Discover tools now that we're authenticated + const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (toolsResponse.status !== 200) throw toolsResponse.data; + setTools(toolsResponse.data.tools); + setServerName(toolsResponse.data.server_name ?? null); + setStep("tool"); + } catch (e: any) { + // If server doesn't support OAuth → show manual token entry + if (e?.status === 400) { + setShowManualToken(true); + setError( + "This server does not support OAuth sign-in. Please enter a token manually.", + ); + } else if (e?.message === "OAuth flow timed out") { + setError("OAuth sign-in timed out. Please try again."); + } else { + const status = e?.status; + let message: string; + if (status === 401 || status === 403) { + message = + "Authentication succeeded but the server still rejected the request. " + + "The token audience may not match. Please try again."; + } else { + message = e?.message || e?.detail || "Failed to complete sign-in"; + } + setError( + typeof message === "string" ? 
message : JSON.stringify(message), + ); + } + } finally { + setOauthLoading(false); + setLoading(false); + oauthAbortRef.current = null; + } + }, [serverUrl, allProviders]); + + // Auto-start OAuth sign-in when server returns 401/403 + useEffect(() => { + if (authRequired && startOAuthRef.current) { + startOAuthRef.current = false; + handleOAuthSignIn(); + } + }, [authRequired, handleOAuthSignIn]); + + const handleConfirm = useCallback(() => { + if (!selectedTool) return; + + const availableTools: Record = {}; + for (const t of tools) { + availableTools[t.name] = { + description: t.description, + input_schema: t.input_schema, + }; + } + + onConfirm({ + serverUrl: serverUrl.trim(), + serverName, + selectedTool: selectedTool.name, + toolInputSchema: selectedTool.input_schema, + availableTools, + credentials, + }); + reset(); + }, [ + selectedTool, + tools, + serverUrl, + serverName, + credentials, + onConfirm, + reset, + ]); + + return ( + !isOpen && handleClose()}> + + + + {step === "url" + ? "Connect to MCP Server" + : `Select a Tool${serverName ? ` — ${serverName}` : ""}`} + + + {step === "url" + ? "Enter the URL of an MCP server to discover its available tools." + : `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`} + + + + {step === "url" && ( +
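For orientation, the happy path through the handlers above reduced to a straight line. Status checks, error handling, and React state updates are omitted; the import path for the generated MCP endpoints is assumed from its use elsewhere in this PR.

import {
  postV2DiscoverAvailableToolsOnAnMcpServer,
  postV2ExchangeOauthCodeForMcpTokens,
  postV2InitiateOauthLoginForAnMcpServer,
} from "@/app/api/__generated__/endpoints/mcp/mcp";
import { openOAuthPopup } from "@/lib/oauth-popup";

async function mcpOAuthHappyPath(serverUrl: string) {
  // 1. Backend discovers the server's OAuth metadata and returns a login URL.
  const login = await postV2InitiateOauthLoginForAnMcpServer({ server_url: serverUrl });
  const { login_url, state_token } = login.data;

  // 2. The user authenticates in a popup; the promise resolves with the auth code.
  const { promise } = openOAuthPopup(login_url, {
    stateToken: state_token,
    useCrossOriginListeners: true,
  });
  const { code } = await promise;

  // 3. Exchange the code for tokens; the credential is stored server-side.
  await postV2ExchangeOauthCodeForMcpTokens({ code, state_token });

  // 4. Re-run discovery; the stored credential is now picked up automatically.
  const discovered = await postV2DiscoverAvailableToolsOnAnMcpServer({
    server_url: serverUrl,
  });
  return discovered.data.tools;
}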
+
+ + setServerUrl(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ + {/* Auth required: show manual token option */} + {authRequired && !showManualToken && ( + + )} + + {/* Manual token entry — only visible when expanded */} + {showManualToken && ( +
+ + setManualToken(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ )} + + {error &&

{error}
} +
+ )} + + {step === "tool" && ( + +
+ {tools.map((tool) => ( + setSelectedTool(tool)} + /> + ))} +
+
+ )} + + + {step === "tool" && ( + + )} + + {step === "url" && ( + + )} + {step === "tool" && ( + + )} + +
+
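The tool card component defined just below strips XML-like tags from tool descriptions by looping the replace to a fixed point, the pattern CodeQL recommends for multi-character sanitization, since a single pass can leave tag fragments that recombine. The idea in isolation:

function stripTags(input: string): string {
  let out = input;
  let prev = "";
  // Repeat until the string stops changing, so that characters uncovered by
  // one pass cannot survive as a freshly formed tag.
  while (prev !== out) {
    prev = out;
    out = out.replace(/<[^>]*>/g, "");
  }
  return out.trim();
}

// stripTags("<b>Fetches</b> current weather") === "Fetches current weather"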
+ ); +} + +// --------------- Tool Card Component --------------- // + +/** Truncate a description to a reasonable length for the collapsed view. */ +function truncateDescription(text: string, maxLen = 120): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen).trimEnd() + "…"; +} + +/** Pretty-print a JSON Schema type for a parameter. */ +function schemaTypeLabel(schema: Record): string { + if (schema.type) return schema.type; + if (schema.anyOf) + return schema.anyOf.map((s: any) => s.type ?? "any").join(" | "); + if (schema.oneOf) + return schema.oneOf.map((s: any) => s.type ?? "any").join(" | "); + return "any"; +} + +function MCPToolCard({ + tool, + selected, + onSelect, +}: { + tool: MCPToolResponse; + selected: boolean; + onSelect: () => void; +}) { + const [expanded, setExpanded] = useState(false); + const schema = tool.input_schema as Record; + const properties = schema?.properties ?? {}; + const required = new Set(schema?.required ?? []); + const paramNames = Object.keys(properties); + + // Strip XML-like tags from description for cleaner display. + // Loop to handle nested tags like ipt> (CodeQL fix). + let cleanDescription = tool.description ?? ""; + let prev = ""; + while (prev !== cleanDescription) { + prev = cleanDescription; + cleanDescription = cleanDescription.replace(/<[^>]*>/g, ""); + } + cleanDescription = cleanDescription.trim(); + + return ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx index 10f4fc8a44..07c6795808 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/__legacy__/ui/button"; import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { beautifyString, cn } from "@/lib/utils"; -import React, { ButtonHTMLAttributes } from "react"; +import React, { ButtonHTMLAttributes, useCallback, useState } from "react"; import { highlightText } from "./helpers"; import { PlusIcon } from "@phosphor-icons/react"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; @@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore"; import { blockDragPreviewStyle } from "./style"; import { useReactFlow } from "@xyflow/react"; import { useNodeStore } from "../../../stores/nodeStore"; +import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api"; +import { + MCPToolDialog, + type MCPToolDialogResult, +} from "@/app/(platform)/build/components/MCPToolDialog"; + interface Props extends ButtonHTMLAttributes { title?: string; description?: string; @@ -33,22 +39,86 @@ export const Block: BlockComponent = ({ ); const { setViewport } = useReactFlow(); const { addBlock } = useNodeStore(); + const [mcpDialogOpen, setMcpDialogOpen] = useState(false); + + const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL; + + const addBlockAndCenter = useCallback( + (block: BlockInfo, hardcodedValues?: Record) => { + const customNode = addBlock(block, hardcodedValues); + setTimeout(() => { + setViewport( + { + x: -customNode.position.x * 0.8 + window.innerWidth / 2, + y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, + zoom: 0.8, + }, + { duration: 500 }, + ); + }, 50); + return 
customNode; + }, + [addBlock, setViewport], + ); + + const updateNodeData = useNodeStore((state) => state.updateNodeData); + + const handleMCPToolConfirm = useCallback( + (result: MCPToolDialogResult) => { + // Derive a display label: prefer server name, fall back to URL hostname. + let serverLabel = result.serverName; + if (!serverLabel) { + try { + serverLabel = new URL(result.serverUrl).hostname; + } catch { + serverLabel = "MCP"; + } + } + + const customNode = addBlockAndCenter(blockData, { + server_url: result.serverUrl, + server_name: serverLabel, + selected_tool: result.selectedTool, + tool_input_schema: result.toolInputSchema, + available_tools: result.availableTools, + credentials: result.credentials ?? undefined, + }); + if (customNode) { + const title = result.selectedTool + ? `${serverLabel}: ${beautifyString(result.selectedTool)}` + : undefined; + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + credentials_optional: true, + ...(title && { customized_name: title }), + }, + }); + } + setMcpDialogOpen(false); + }, + [addBlockAndCenter, blockData, updateNodeData], + ); const handleClick = () => { - const customNode = addBlock(blockData); - setTimeout(() => { - setViewport( - { - x: -customNode.position.x * 0.8 + window.innerWidth / 2, - y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, - zoom: 0.8, + if (isMCPBlock) { + setMcpDialogOpen(true); + return; + } + const customNode = addBlockAndCenter(blockData); + // Set customized_name for agent blocks so the agent's name persists + if (customNode && blockData.id === SpecialBlockID.AGENT) { + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + customized_name: blockData.name, }, - { duration: 500 }, - ); - }, 50); + }); + } }; const handleDragStart = (e: React.DragEvent) => { + if (isMCPBlock) return; e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData)); @@ -71,46 +141,56 @@ export const Block: BlockComponent = ({ : undefined; return ( -
- +
+ {title && ( + + {highlightText(beautifyString(title), highlightedText)} + + )} + {description && ( + + {highlightText(description, highlightedText)} + + )} +
+
+ +
+ + {isMCPBlock && ( + setMcpDialogOpen(false)} + onConfirm={handleMCPToolConfirm} + /> + )} + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts index 2fde427330..0f5021351d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts @@ -9,4 +9,5 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index 71ade81a9f..c118057963 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -15,11 +15,16 @@ import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai"; import { useEffect, useRef, useState } from "react"; import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; import { RunBlockTool } from "../../tools/RunBlock/RunBlock"; import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs"; +import { GenericTool } from "../../tools/GenericTool/GenericTool"; import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput"; // --------------------------------------------------------------------------- @@ -254,7 +259,31 @@ export const ChatMessagesContainer = ({ part={part as ToolUIPart} /> ); + case "tool-search_feature_requests": + return ( + + ); + case "tool-create_feature_request": + return ( + + ); default: + // Render a generic tool indicator for SDK built-in + // tools (Read, Glob, Grep, etc.) 
or any unrecognized tool + if (part.type.startsWith("tool-")) { + return ( + + ); + } return null; } })} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx index 6030665f1c..8a35f939ca 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx @@ -14,6 +14,10 @@ import { Text } from "@/components/atoms/Text/Text"; import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider"; import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../tools/RunAgent/RunAgent"; @@ -45,6 +49,8 @@ const SECTIONS = [ "Tool: Create Agent", "Tool: Edit Agent", "Tool: View Agent Output", + "Tool: Search Feature Requests", + "Tool: Create Feature Request", "Full Conversation Example", ] as const; @@ -1421,6 +1427,235 @@ export default function StyleguidePage() { + {/* ============================================================= */} + {/* SEARCH FEATURE REQUESTS */} + {/* ============================================================= */} + +
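GenericTool, added below, derives its label purely from the part type, via the extractToolName/formatToolName helpers in GenericTool.tsx. The mapping in one function, with a usage note:

// "tool-search_docs" becomes "Search docs"; SDK built-ins like "tool-Read" stay "Read".
function toolLabel(partType: string): string {
  const name = partType.replace(/^tool-/, "");
  return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
}

// toolLabel("tool-search_feature_requests") === "Search feature requests"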
+
+ + {/* ============================================================= */} + {/* CREATE FEATURE REQUEST */} + {/* ============================================================= */} + +
+
+ {/* ============================================================= */} {/* FULL CONVERSATION EXAMPLE */} {/* ============================================================= */} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx new file mode 100644 index 0000000000..fcd4624b6a --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx @@ -0,0 +1,227 @@ +"use client"; + +import type { ToolUIPart } from "ai"; +import { useMemo } from "react"; + +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardHeader, + ContentCardTitle, + ContentGrid, + ContentMessage, + ContentSuggestionsList, +} from "../../components/ToolAccordion/AccordionContent"; +import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { + AccordionIcon, + getAccordionTitle, + getAnimationText, + getFeatureRequestOutput, + isCreatedOutput, + isErrorOutput, + isNoResultsOutput, + isSearchResultsOutput, + ToolIcon, + type FeatureRequestToolType, +} from "./helpers"; + +export interface FeatureRequestToolPart { + type: FeatureRequestToolType; + toolCallId: string; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +interface Props { + part: FeatureRequestToolPart; +} + +function truncate(text: string, maxChars: number): string { + const trimmed = text.trim(); + if (trimmed.length <= maxChars) return trimmed; + return `${trimmed.slice(0, maxChars).trimEnd()}…`; +} + +export function SearchFeatureRequestsTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const searchOutput = + isOutputAvailable && output && isSearchResultsOutput(output) + ? output + : null; + const noResultsOutput = + isOutputAvailable && output && isNoResultsOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && + ((!!searchOutput && searchOutput.count > 0) || + !!noResultsOutput || + !!errorOutput); + + const accordionDescription = + hasExpandableContent && searchOutput + ? `Found ${searchOutput.count} result${searchOutput.count === 1 ? "" : "s"} for "${searchOutput.query}"` + : hasExpandableContent && (noResultsOutput || errorOutput) + ? ((noResultsOutput ?? errorOutput)?.message ?? null) + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {searchOutput && ( + + {searchOutput.results.map((r) => ( + + + {r.title} + + {r.description && ( + + {truncate(r.description, 200)} + + )} + + ))} + + )} + + {noResultsOutput && ( +
+ {noResultsOutput.message} + {noResultsOutput.suggestions && + noResultsOutput.suggestions.length > 0 && ( + + )} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+ ); +} + +export function CreateFeatureRequestTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const createdOutput = + isOutputAvailable && output && isCreatedOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && (!!createdOutput || !!errorOutput); + + const accordionDescription = + hasExpandableContent && createdOutput + ? createdOutput.issue_title + : hasExpandableContent && errorOutput + ? errorOutput.message + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {createdOutput && ( + + + {createdOutput.issue_title} + +
+ + {createdOutput.is_new_issue ? "New" : "Existing"} + +
+ {createdOutput.message} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
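The helpers that follow normalize part.output defensively: the payload may arrive as a typed object, as a JSON string, or as an object that lost its type discriminator, so parsing falls back to structural checks. The pattern, reduced (union members abridged for the sketch):

type FeatureRequestOutput =
  | { type: "feature_request_search"; results: unknown[]; query: string }
  | { type: "feature_request_created"; issue_identifier: string }
  | { type: "no_results"; suggestions?: string[] }
  | { type: "error"; error?: string };

function parse(output: unknown): FeatureRequestOutput | null {
  // Strings are re-parsed as JSON and run through the same checks.
  if (typeof output === "string") {
    try {
      return parse(JSON.parse(output));
    } catch {
      return null;
    }
  }
  if (!output || typeof output !== "object") return null;
  // Trust an explicit discriminator when present;
  if ("type" in output && typeof (output as { type?: unknown }).type === "string") {
    return output as FeatureRequestOutput;
  }
  // otherwise fall back to structural fingerprints.
  if ("results" in output && "query" in output) return output as FeatureRequestOutput;
  if ("issue_identifier" in output) return output as FeatureRequestOutput;
  return null;
}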
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx new file mode 100644 index 0000000000..75133905b1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx @@ -0,0 +1,271 @@ +import { + CheckCircleIcon, + LightbulbIcon, + MagnifyingGlassIcon, + PlusCircleIcon, +} from "@phosphor-icons/react"; +import type { ToolUIPart } from "ai"; + +/* ------------------------------------------------------------------ */ +/* Types (local until API client is regenerated) */ +/* ------------------------------------------------------------------ */ + +interface FeatureRequestInfo { + id: string; + identifier: string; + title: string; + description?: string | null; +} + +export interface FeatureRequestSearchResponse { + type: "feature_request_search"; + message: string; + results: FeatureRequestInfo[]; + count: number; + query: string; +} + +export interface FeatureRequestCreatedResponse { + type: "feature_request_created"; + message: string; + issue_id: string; + issue_identifier: string; + issue_title: string; + issue_url: string; + is_new_issue: boolean; + customer_name: string; +} + +interface NoResultsResponse { + type: "no_results"; + message: string; + suggestions?: string[]; +} + +interface ErrorResponse { + type: "error"; + message: string; + error?: string; +} + +export type FeatureRequestOutput = + | FeatureRequestSearchResponse + | FeatureRequestCreatedResponse + | NoResultsResponse + | ErrorResponse; + +export type FeatureRequestToolType = + | "tool-search_feature_requests" + | "tool-create_feature_request" + | string; + +/* ------------------------------------------------------------------ */ +/* Output parsing */ +/* ------------------------------------------------------------------ */ + +function parseOutput(output: unknown): FeatureRequestOutput | null { + if (!output) return null; + if (typeof output === "string") { + const trimmed = output.trim(); + if (!trimmed) return null; + try { + return parseOutput(JSON.parse(trimmed) as unknown); + } catch { + return null; + } + } + if (typeof output === "object") { + const type = (output as { type?: unknown }).type; + if ( + type === "feature_request_search" || + type === "feature_request_created" || + type === "no_results" || + type === "error" + ) { + return output as FeatureRequestOutput; + } + // Fallback structural checks + if ("results" in output && "query" in output) + return output as FeatureRequestSearchResponse; + if ("issue_identifier" in output) + return output as FeatureRequestCreatedResponse; + if ("suggestions" in output && !("error" in output)) + return output as NoResultsResponse; + if ("error" in output || "details" in output) + return output as ErrorResponse; + } + return null; +} + +export function getFeatureRequestOutput( + part: unknown, +): FeatureRequestOutput | null { + if (!part || typeof part !== "object") return null; + return parseOutput((part as { output?: unknown }).output); +} + +/* ------------------------------------------------------------------ */ +/* Type guards */ +/* ------------------------------------------------------------------ */ + +export function isSearchResultsOutput( + output: FeatureRequestOutput, +): output is FeatureRequestSearchResponse { + return ( + output.type === "feature_request_search" || + ("results" in output && "query" in output) + ); +} + +export function isCreatedOutput( + output: 
FeatureRequestOutput, +): output is FeatureRequestCreatedResponse { + return ( + output.type === "feature_request_created" || "issue_identifier" in output + ); +} + +export function isNoResultsOutput( + output: FeatureRequestOutput, +): output is NoResultsResponse { + return ( + output.type === "no_results" || + ("suggestions" in output && !("error" in output)) + ); +} + +export function isErrorOutput( + output: FeatureRequestOutput, +): output is ErrorResponse { + return output.type === "error" || "error" in output; +} + +/* ------------------------------------------------------------------ */ +/* Accordion metadata */ +/* ------------------------------------------------------------------ */ + +export function getAccordionTitle( + toolType: FeatureRequestToolType, + output: FeatureRequestOutput, +): string { + if (toolType === "tool-search_feature_requests") { + if (isSearchResultsOutput(output)) return "Feature requests"; + if (isNoResultsOutput(output)) return "No feature requests found"; + return "Feature request search error"; + } + if (isCreatedOutput(output)) { + return output.is_new_issue + ? "Feature request created" + : "Added to feature request"; + } + if (isErrorOutput(output)) return "Feature request error"; + return "Feature request"; +} + +/* ------------------------------------------------------------------ */ +/* Animation text */ +/* ------------------------------------------------------------------ */ + +interface AnimationPart { + type: FeatureRequestToolType; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +export function getAnimationText(part: AnimationPart): string { + if (part.type === "tool-search_feature_requests") { + const query = (part.input as { query?: string } | undefined)?.query?.trim(); + const queryText = query ? ` for "${query}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Searching feature requests${queryText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Searching feature requests${queryText}`; + if (isSearchResultsOutput(output)) { + return `Found ${output.count} feature request${output.count === 1 ? "" : "s"}${queryText}`; + } + if (isNoResultsOutput(output)) + return `No feature requests found${queryText}`; + return `Error searching feature requests${queryText}`; + } + case "output-error": + return `Error searching feature requests${queryText}`; + default: + return "Searching feature requests"; + } + } + + // create_feature_request + const title = (part.input as { title?: string } | undefined)?.title?.trim(); + const titleText = title ? ` "${title}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Creating feature request${titleText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Creating feature request${titleText}`; + if (isCreatedOutput(output)) { + return output.is_new_issue + ? 
"Feature request created" + : "Added to existing feature request"; + } + if (isErrorOutput(output)) return "Error creating feature request"; + return `Created feature request${titleText}`; + } + case "output-error": + return "Error creating feature request"; + default: + return "Creating feature request"; + } +} + +/* ------------------------------------------------------------------ */ +/* Icons */ +/* ------------------------------------------------------------------ */ + +export function ToolIcon({ + toolType, + isStreaming, + isError, +}: { + toolType: FeatureRequestToolType; + isStreaming?: boolean; + isError?: boolean; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? PlusCircleIcon + : MagnifyingGlassIcon; + + return ( + + ); +} + +export function AccordionIcon({ + toolType, +}: { + toolType: FeatureRequestToolType; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? CheckCircleIcon + : LightbulbIcon; + return ; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx new file mode 100644 index 0000000000..677f1d01d1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx @@ -0,0 +1,63 @@ +"use client"; + +import { ToolUIPart } from "ai"; +import { GearIcon } from "@phosphor-icons/react"; +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; + +interface Props { + part: ToolUIPart; +} + +function extractToolName(part: ToolUIPart): string { + // ToolUIPart.type is "tool-{name}", extract the name portion. + return part.type.replace(/^tool-/, ""); +} + +function formatToolName(name: string): string { + // "search_docs" → "Search docs", "Read" → "Read" + return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase()); +} + +function getAnimationText(part: ToolUIPart): string { + const label = formatToolName(extractToolName(part)); + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Running ${label}…`; + case "output-available": + return `${label} completed`; + case "output-error": + return `${label} failed`; + default: + return `Running ${label}…`; + } +} + +export function GenericTool({ part }: Props) { + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = part.state === "output-error"; + + return ( +
+
+ + +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx index e1cb030449..6e2cbe90d7 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx @@ -3,6 +3,7 @@ import type { ToolUIPart } from "ai"; import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { BlockDetailsCard } from "./components/BlockDetailsCard/BlockDetailsCard"; import { BlockOutputCard } from "./components/BlockOutputCard/BlockOutputCard"; import { ErrorCard } from "./components/ErrorCard/ErrorCard"; import { SetupRequirementsCard } from "./components/SetupRequirementsCard/SetupRequirementsCard"; @@ -11,6 +12,7 @@ import { getAnimationText, getRunBlockToolOutput, isRunBlockBlockOutput, + isRunBlockDetailsOutput, isRunBlockErrorOutput, isRunBlockSetupRequirementsOutput, ToolIcon, @@ -41,6 +43,7 @@ export function RunBlockTool({ part }: Props) { part.state === "output-available" && !!output && (isRunBlockBlockOutput(output) || + isRunBlockDetailsOutput(output) || isRunBlockSetupRequirementsOutput(output) || isRunBlockErrorOutput(output)); @@ -58,6 +61,10 @@ export function RunBlockTool({ part }: Props) { {isRunBlockBlockOutput(output) && } + {isRunBlockDetailsOutput(output) && ( + + )} + {isRunBlockSetupRequirementsOutput(output) && ( )} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx new file mode 100644 index 0000000000..6e133ca93b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx @@ -0,0 +1,188 @@ +import type { Meta, StoryObj } from "@storybook/nextjs"; +import { ResponseType } from "@/app/api/__generated__/models/responseType"; +import type { BlockDetailsResponse } from "../../helpers"; +import { BlockDetailsCard } from "./BlockDetailsCard"; + +const meta: Meta = { + title: "Copilot/RunBlock/BlockDetailsCard", + component: BlockDetailsCard, + parameters: { + layout: "centered", + }, + tags: ["autodocs"], + decorators: [ + (Story) => ( +
+ +
+ ), + ], +}; + +export default meta; +type Story = StoryObj; + +const baseBlock: BlockDetailsResponse = { + type: ResponseType.block_details, + message: + "Here are the details for the GetWeather block. Provide the required inputs to run it.", + session_id: "session-123", + user_authenticated: true, + block: { + id: "block-abc-123", + name: "GetWeather", + description: "Fetches current weather data for a given location.", + inputs: { + type: "object", + properties: { + location: { + title: "Location", + type: "string", + description: + "City name or coordinates (e.g. 'London' or '51.5,-0.1')", + }, + units: { + title: "Units", + type: "string", + description: "Temperature units: 'metric' or 'imperial'", + }, + }, + required: ["location"], + }, + outputs: { + type: "object", + properties: { + temperature: { + title: "Temperature", + type: "number", + description: "Current temperature in the requested units", + }, + condition: { + title: "Condition", + type: "string", + description: "Weather condition description (e.g. 'Sunny', 'Rain')", + }, + }, + }, + credentials: [], + }, +}; + +export const Default: Story = { + args: { + output: baseBlock, + }, +}; + +export const InputsOnly: Story = { + args: { + output: { + ...baseBlock, + message: "This block requires inputs. No outputs are defined.", + block: { + ...baseBlock.block, + outputs: {}, + }, + }, + }, +}; + +export const OutputsOnly: Story = { + args: { + output: { + ...baseBlock, + message: "This block has no required inputs.", + block: { + ...baseBlock.block, + inputs: {}, + }, + }, + }, +}; + +export const ManyFields: Story = { + args: { + output: { + ...baseBlock, + message: "Block with many input and output fields.", + block: { + ...baseBlock.block, + name: "SendEmail", + description: "Sends an email via SMTP.", + inputs: { + type: "object", + properties: { + to: { + title: "To", + type: "string", + description: "Recipient email address", + }, + subject: { + title: "Subject", + type: "string", + description: "Email subject line", + }, + body: { + title: "Body", + type: "string", + description: "Email body content", + }, + cc: { + title: "CC", + type: "string", + description: "CC recipients (comma-separated)", + }, + bcc: { + title: "BCC", + type: "string", + description: "BCC recipients (comma-separated)", + }, + }, + required: ["to", "subject", "body"], + }, + outputs: { + type: "object", + properties: { + message_id: { + title: "Message ID", + type: "string", + description: "Unique ID of the sent email", + }, + status: { + title: "Status", + type: "string", + description: "Delivery status", + }, + }, + }, + }, + }, + }, +}; + +export const NoFieldDescriptions: Story = { + args: { + output: { + ...baseBlock, + message: "Fields without descriptions.", + block: { + ...baseBlock.block, + name: "SimpleBlock", + inputs: { + type: "object", + properties: { + input_a: { title: "Input A", type: "string" }, + input_b: { title: "Input B", type: "number" }, + }, + required: ["input_a"], + }, + outputs: { + type: "object", + properties: { + result: { title: "Result", type: "string" }, + }, + }, + }, + }, + }, +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx new file mode 100644 index 0000000000..fdbf115222 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx @@ 
-0,0 +1,103 @@ +"use client"; + +import type { BlockDetailsResponse } from "../../helpers"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardTitle, + ContentGrid, + ContentMessage, +} from "../../../../components/ToolAccordion/AccordionContent"; + +interface Props { + output: BlockDetailsResponse; +} + +function SchemaFieldList({ + title, + properties, + required, +}: { + title: string; + properties: Record; + required?: string[]; +}) { + const entries = Object.entries(properties); + if (entries.length === 0) return null; + + const requiredSet = new Set(required ?? []); + + return ( + + {title} +
+ {entries.map(([name, schema]) => { + const field = schema as Record | undefined; + const fieldTitle = + typeof field?.title === "string" ? field.title : name; + const fieldType = + typeof field?.type === "string" ? field.type : "unknown"; + const description = + typeof field?.description === "string" + ? field.description + : undefined; + + return ( +
+
+ + {fieldTitle} + +
+ {fieldType} + {requiredSet.has(name) && ( + Required + )} +
+
+ {description && ( + + {description} + + )} +
+ ); + })} +
+
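SchemaFieldList above walks a JSON-schema properties map and flags anything listed in required; the extraction step on its own, with an invented field for the usage note:

interface FieldRow {
  name: string;
  title: string;
  type: string;
  required: boolean;
  description?: string;
}

// Flatten a JSON-schema object into display rows, tolerating missing metadata.
function schemaToRows(
  properties: Record<string, unknown>,
  required: string[] = [],
): FieldRow[] {
  const requiredSet = new Set(required);
  return Object.entries(properties).map(([name, schema]) => {
    const field = schema as Record<string, unknown> | undefined;
    return {
      name,
      title: typeof field?.title === "string" ? field.title : name,
      type: typeof field?.type === "string" ? field.type : "unknown",
      required: requiredSet.has(name),
      description:
        typeof field?.description === "string" ? field.description : undefined,
    };
  });
}

// schemaToRows({ location: { title: "Location", type: "string" } }, ["location"])
// yields [{ name: "location", title: "Location", type: "string",
//           required: true, description: undefined }]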
+ ); +} + +export function BlockDetailsCard({ output }: Props) { + const inputs = output.block.inputs as { + properties?: Record; + required?: string[]; + } | null; + const outputs = output.block.outputs as { + properties?: Record; + required?: string[]; + } | null; + + return ( + + {output.message} + + {inputs?.properties && Object.keys(inputs.properties).length > 0 && ( + + )} + + {outputs?.properties && Object.keys(outputs.properties).length > 0 && ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx index b8625988cd..6e56154a5e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx @@ -10,18 +10,37 @@ import { import type { ToolUIPart } from "ai"; import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader"; +/** Block details returned on first run_block attempt (before input_data provided). */ +export interface BlockDetailsResponse { + type: typeof ResponseType.block_details; + message: string; + session_id?: string | null; + block: { + id: string; + name: string; + description: string; + inputs: Record; + outputs: Record; + credentials: unknown[]; + }; + user_authenticated: boolean; +} + export interface RunBlockInput { block_id?: string; + block_name?: string; input_data?: Record; } export type RunBlockToolOutput = | SetupRequirementsResponse + | BlockDetailsResponse | BlockOutputResponse | ErrorResponse; const RUN_BLOCK_OUTPUT_TYPES = new Set([ ResponseType.setup_requirements, + ResponseType.block_details, ResponseType.block_output, ResponseType.error, ]); @@ -35,6 +54,15 @@ export function isRunBlockSetupRequirementsOutput( ); } +export function isRunBlockDetailsOutput( + output: RunBlockToolOutput, +): output is BlockDetailsResponse { + return ( + output.type === ResponseType.block_details || + ("block" in output && typeof output.block === "object") + ); +} + export function isRunBlockBlockOutput( output: RunBlockToolOutput, ): output is BlockOutputResponse { @@ -64,6 +92,7 @@ function parseOutput(output: unknown): RunBlockToolOutput | null { return output as RunBlockToolOutput; } if ("block_id" in output) return output as BlockOutputResponse; + if ("block" in output) return output as BlockDetailsResponse; if ("setup_info" in output) return output as SetupRequirementsResponse; if ("error" in output || "details" in output) return output as ErrorResponse; @@ -84,17 +113,25 @@ export function getAnimationText(part: { output?: unknown; }): string { const input = part.input as RunBlockInput | undefined; + const blockName = input?.block_name?.trim(); const blockId = input?.block_id?.trim(); - const blockText = blockId ? ` "${blockId}"` : ""; + // Prefer block_name if available, otherwise fall back to block_id + const blockText = blockName + ? ` "${blockName}"` + : blockId + ? 
` "${blockId}"` + : ""; switch (part.state) { case "input-streaming": case "input-available": - return `Running the block${blockText}`; + return `Running${blockText}`; case "output-available": { const output = parseOutput(part.output); - if (!output) return `Running the block${blockText}`; + if (!output) return `Running${blockText}`; if (isRunBlockBlockOutput(output)) return `Ran "${output.block_name}"`; + if (isRunBlockDetailsOutput(output)) + return `Details for "${output.block.name}"`; if (isRunBlockSetupRequirementsOutput(output)) { return `Setup needed for "${output.setup_info.agent_name}"`; } @@ -158,6 +195,21 @@ export function getAccordionMeta(output: RunBlockToolOutput): { }; } + if (isRunBlockDetailsOutput(output)) { + const inputKeys = Object.keys( + (output.block.inputs as { properties?: Record }) + ?.properties ?? {}, + ); + return { + icon, + title: output.block.name, + description: + inputKeys.length > 0 + ? `${inputKeys.length} input field${inputKeys.length === 1 ? "" : "s"} available` + : output.message, + }; + } + if (isRunBlockSetupRequirementsOutput(output)) { const missingCredsCount = Object.keys( (output.setup_info.user_readiness?.missing_credentials ?? {}) as Record< diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 1f975ff575..5d91f67981 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1053,6 +1053,7 @@ "$ref": "#/components/schemas/ClarificationNeededResponse" }, { "$ref": "#/components/schemas/BlockListResponse" }, + { "$ref": "#/components/schemas/BlockDetailsResponse" }, { "$ref": "#/components/schemas/BlockOutputResponse" }, { "$ref": "#/components/schemas/DocSearchResultsResponse" }, { "$ref": "#/components/schemas/DocPageResponse" }, @@ -4625,6 +4626,128 @@ } } }, + "/api/mcp/discover-tools": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Discover available tools on an MCP server", + "description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.", + "operationId": "postV2Discover available tools on an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DiscoverToolsResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/callback": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Exchange OAuth code for MCP tokens", + "description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.", + "operationId": "postV2Exchange oauth code for mcp tokens", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthCallbackRequest" + } + } + }, + 
"required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CredentialsMetaResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/login": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Initiate OAuth login for an MCP server", + "description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup", + "operationId": "postV2Initiate oauth login for an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthLoginResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, "/api/oauth/app/{client_id}": { "get": { "tags": ["oauth"], @@ -7315,6 +7438,58 @@ "enum": ["run", "byte", "second"], "title": "BlockCostType" }, + "BlockDetails": { + "properties": { + "id": { "type": "string", "title": "Id" }, + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "inputs": { + "additionalProperties": true, + "type": "object", + "title": "Inputs", + "default": {} + }, + "outputs": { + "additionalProperties": true, + "type": "object", + "title": "Outputs", + "default": {} + }, + "credentials": { + "items": { "$ref": "#/components/schemas/CredentialsMetaInput" }, + "type": "array", + "title": "Credentials", + "default": [] + } + }, + "type": "object", + "required": ["id", "name", "description"], + "title": "BlockDetails", + "description": "Detailed block information." + }, + "BlockDetailsResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "block_details" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "block": { "$ref": "#/components/schemas/BlockDetails" }, + "user_authenticated": { + "type": "boolean", + "title": "User Authenticated", + "default": false + } + }, + "type": "object", + "required": ["message", "block"], + "title": "BlockDetailsResponse", + "description": "Response for block details (first run_block attempt)." 
+ }, "BlockInfo": { "properties": { "id": { "type": "string", "title": "Id" }, @@ -7379,29 +7554,24 @@ "input_schema": { "additionalProperties": true, "type": "object", - "title": "Input Schema" + "title": "Input Schema", + "description": "Full JSON schema for block inputs" }, "output_schema": { "additionalProperties": true, "type": "object", - "title": "Output Schema" + "title": "Output Schema", + "description": "Full JSON schema for block outputs" }, "required_inputs": { "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" }, "type": "array", "title": "Required Inputs", - "description": "List of required input fields for this block" + "description": "List of input fields for this block" } }, "type": "object", - "required": [ - "id", - "name", - "description", - "categories", - "input_schema", - "output_schema" - ], + "required": ["id", "name", "description", "categories"], "title": "BlockInfoSummary", "description": "Summary of a block for search results." }, @@ -7447,7 +7617,7 @@ "usage_hint": { "type": "string", "title": "Usage Hint", - "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema." + "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs." } }, "type": "object", @@ -8017,7 +8187,7 @@ "host": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Host", - "description": "Host pattern for host-scoped credentials" + "description": "Host pattern for host-scoped or MCP server URL for MCP credentials" } }, "type": "object", @@ -8037,6 +8207,45 @@ "required": ["version_counts"], "title": "DeleteGraphResponse" }, + "DiscoverToolsRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server" + }, + "auth_token": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Auth Token", + "description": "Optional Bearer token for authenticated MCP servers" + } + }, + "type": "object", + "required": ["server_url"], + "title": "DiscoverToolsRequest", + "description": "Request to discover tools on an MCP server." + }, + "DiscoverToolsResponse": { + "properties": { + "tools": { + "items": { "$ref": "#/components/schemas/MCPToolResponse" }, + "type": "array", + "title": "Tools" + }, + "server_name": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Server Name" + }, + "protocol_version": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Protocol Version" + } + }, + "type": "object", + "required": ["tools"], + "title": "DiscoverToolsResponse", + "description": "Response containing the list of tools available on an MCP server." + }, "DocPageResponse": { "properties": { "type": { @@ -9808,6 +10017,62 @@ "required": ["login_url", "state_token"], "title": "LoginResponse" }, + "MCPOAuthCallbackRequest": { + "properties": { + "code": { + "type": "string", + "title": "Code", + "description": "Authorization code from OAuth callback" + }, + "state_token": { + "type": "string", + "title": "State Token", + "description": "State token for CSRF verification" + } + }, + "type": "object", + "required": ["code", "state_token"], + "title": "MCPOAuthCallbackRequest", + "description": "Request to exchange an OAuth code for tokens." 
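A sample request/response pair for /api/mcp/discover-tools, shaped by the DiscoverToolsRequest and DiscoverToolsResponse schemas above (server URL, tool, and version string are invented):

const discoverRequest = {
  server_url: "https://mcp.example.com/mcp",
  auth_token: null, // optional; a stored MCP credential is used automatically
};

const discoverResponse = {
  tools: [
    {
      name: "get_weather",
      description: "Fetch current weather for a location.",
      input_schema: {
        type: "object",
        properties: { location: { type: "string" } },
        required: ["location"],
      },
    },
  ],
  server_name: "Example Weather MCP",
  protocol_version: "2025-03-26",
};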
+ }, + "MCPOAuthLoginRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server that requires OAuth" + } + }, + "type": "object", + "required": ["server_url"], + "title": "MCPOAuthLoginRequest", + "description": "Request to start an OAuth flow for an MCP server." + }, + "MCPOAuthLoginResponse": { + "properties": { + "login_url": { "type": "string", "title": "Login Url" }, + "state_token": { "type": "string", "title": "State Token" } + }, + "type": "object", + "required": ["login_url", "state_token"], + "title": "MCPOAuthLoginResponse", + "description": "Response with the OAuth login URL for the user to authenticate." + }, + "MCPToolResponse": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "input_schema": { + "additionalProperties": true, + "type": "object", + "title": "Input Schema" + } + }, + "type": "object", + "required": ["name", "description", "input_schema"], + "title": "MCPToolResponse", + "description": "A single MCP tool returned by discovery." + }, "MarketplaceListing": { "properties": { "id": { "type": "string", "title": "Id" }, @@ -11053,6 +11318,7 @@ "agent_saved", "clarification_needed", "block_list", + "block_details", "block_output", "doc_search_results", "doc_page", @@ -11064,7 +11330,12 @@ "operation_started", "operation_pending", "operation_in_progress", - "input_validation_error" + "input_validation_error", + "web_fetch", + "bash_exec", + "operation_status", + "feature_request_search", + "feature_request_created" ], "title": "ResponseType", "description": "Types of tool responses." diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx index 135a960431..22d0a318a9 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx @@ -38,13 +38,8 @@ export function CredentialsGroupedView({ const allProviders = useContext(CredentialsProvidersContext); const { userCredentialFields, systemCredentialFields } = useMemo( - () => - splitCredentialFieldsBySystem( - credentialFields, - allProviders, - inputCredentials, - ), - [credentialFields, allProviders, inputCredentials], + () => splitCredentialFieldsBySystem(credentialFields, allProviders), + [credentialFields, allProviders], ); const hasSystemCredentials = systemCredentialFields.length > 0; @@ -86,11 +81,13 @@ export function CredentialsGroupedView({ const providerNames = schema.credentials_provider || []; const credentialTypes = schema.credentials_types || []; const requiredScopes = schema.credentials_scopes; + const discriminatorValues = schema.discriminator_values; const savedCredential = findSavedCredentialByProviderAndType( providerNames, credentialTypes, requiredScopes, allProviders, + discriminatorValues, ); if (savedCredential) { diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts index 5f439d3a32..2d8d001a72 100644 --- 
a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts @@ -23,10 +23,35 @@ function hasRequiredScopes( return true; } +/** Check if a credential matches the discriminator values (e.g. MCP server URL). */ +function matchesDiscriminatorValues( + credential: { host?: string | null; provider: string; type: string }, + discriminatorValues?: string[], +) { + // MCP OAuth2 credentials must match by server URL + if (credential.type === "oauth2" && credential.provider === "mcp") { + if (!discriminatorValues || discriminatorValues.length === 0) return false; + return ( + credential.host != null && discriminatorValues.includes(credential.host) + ); + } + // Host-scoped credentials match by host + if (credential.type === "host_scoped" && credential.host) { + if (!discriminatorValues || discriminatorValues.length === 0) return true; + return discriminatorValues.some((v) => { + try { + return new URL(v).hostname === credential.host; + } catch { + return false; + } + }); + } + return true; +} + export function splitCredentialFieldsBySystem( credentialFields: CredentialField[], allProviders: CredentialsProvidersContextType | null, - inputCredentials?: Record, ) { if (!allProviders || credentialFields.length === 0) { return { @@ -52,17 +77,9 @@ export function splitCredentialFieldsBySystem( } } - const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => { - const aIsSet = Boolean(inputCredentials?.[a[0]]); - const bIsSet = Boolean(inputCredentials?.[b[0]]); - - if (aIsSet === bIsSet) return 0; - return aIsSet ? 1 : -1; - }; - return { - userCredentialFields: userFields.sort(sortByUnsetFirst), - systemCredentialFields: systemFields.sort(sortByUnsetFirst), + userCredentialFields: userFields, + systemCredentialFields: systemFields, }; } @@ -160,6 +177,7 @@ export function findSavedCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = allProviders?.[providerName]; @@ -176,9 +194,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -190,9 +213,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -214,6 +242,7 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = 
allProviders?.[providerName]; @@ -230,9 +259,14 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts index 509713ff1e..9ab2e08141 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts @@ -5,14 +5,14 @@ import { BlockIOCredentialsSubSchema, CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; +import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp"; +import { openOAuthPopup } from "@/lib/oauth-popup"; import { useQueryClient } from "@tanstack/react-query"; import { useEffect, useRef, useState } from "react"; import { filterSystemCredentials, getActionButtonText, getSystemCredentials, - OAUTH_TIMEOUT_MS, - OAuthPopupResultMessage, } from "./helpers"; export type CredentialsInputState = ReturnType; @@ -57,6 +57,14 @@ export function useCredentialsInput({ const queryClient = useQueryClient(); const credentials = useCredentials(schema, siblingInputs); const hasAttemptedAutoSelect = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); const deleteCredentialsMutation = useDeleteV1DeleteCredentials({ mutation: { @@ -81,11 +89,14 @@ export function useCredentialsInput({ } }, [credentials, onLoaded]); - // Unselect credential if not available + // Unselect credential if not available in the loaded credential list. + // Skip when no credentials have been loaded yet (empty list could mean + // the provider data hasn't finished loading, not that the credential is invalid). useEffect(() => { if (readOnly) return; if (!credentials || !("savedCredentials" in credentials)) return; const availableCreds = credentials.savedCredentials; + if (availableCreds.length === 0) return; if ( selectedCredential && !availableCreds.some((c) => c.id === selectedCredential.id) @@ -110,7 +121,9 @@ export function useCredentialsInput({ if (hasAttemptedAutoSelect.current) return; hasAttemptedAutoSelect.current = true; - if (isOptional) return; + // Auto-select if exactly one credential matches. + // For optional fields with multiple options, let the user choose. 
+ if (isOptional && savedCreds.length > 1) return; const cred = savedCreds[0]; onSelectCredential({ @@ -148,7 +161,9 @@ export function useCredentialsInput({ supportsHostScoped, savedCredentials, oAuthCallback, + mcpOAuthCallback, isSystemProvider, + discriminatorValue, } = credentials; // Split credentials into user and system @@ -157,72 +172,66 @@ export function useCredentialsInput({ async function handleOAuthLogin() { setOAuthError(null); - const { login_url, state_token } = await api.oAuthLogin( - provider, - schema.credentials_scopes, - ); - setOAuth2FlowInProgress(true); - const popup = window.open(login_url, "_blank", "popup=true"); - if (!popup) { - throw new Error( - "Failed to open popup window. Please allow popups for this site.", + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + // MCP uses dynamic OAuth discovery per server URL + const isMCP = provider === "mcp" && !!discriminatorValue; + + try { + let login_url: string; + let state_token: string; + + if (isMCP) { + const mcpLoginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: discriminatorValue!, + }); + if (mcpLoginResponse.status !== 200) throw mcpLoginResponse.data; + ({ login_url, state_token } = mcpLoginResponse.data); + } else { + ({ login_url, state_token } = await api.oAuthLogin( + provider, + schema.credentials_scopes, + )); + } + + setOAuth2FlowInProgress(true); + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: isMCP, + // Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result" + acceptMessageTypes: isMCP + ? ["mcp_oauth_result"] + : ["oauth_popup_result"], + }); + + oauthAbortRef.current = cleanup.abort; + // Expose abort signal for the waiting modal's cancel button + const controller = new AbortController(); + cleanup.signal.addEventListener("abort", () => + controller.abort("completed"), ); - } + setOAuthPopupController(controller); - const controller = new AbortController(); - setOAuthPopupController(controller); - controller.signal.onabort = () => { - console.debug("OAuth flow aborted"); - setOAuth2FlowInProgress(false); - popup.close(); - }; + const result = await promise; - const handleMessage = async (e: MessageEvent) => { - console.debug("Message received:", e.data); - if ( - typeof e.data != "object" || - !("message_type" in e.data) || - e.data.message_type !== "oauth_popup_result" - ) { - console.debug("Ignoring irrelevant message"); - return; - } + // Exchange code for tokens via the provider (updates credential cache) + const credentialResult = isMCP + ? 
await mcpOAuthCallback(result.code, state_token) + : await oAuthCallback(result.code, result.state); - if (!e.data.success) { - console.error("OAuth flow failed:", e.data.message); - setOAuthError(`OAuth flow failed: ${e.data.message}`); - setOAuth2FlowInProgress(false); - return; - } - - if (e.data.state !== state_token) { - console.error("Invalid state token received"); - setOAuthError("Invalid state token received"); - setOAuth2FlowInProgress(false); - return; - } - - try { - console.debug("Processing OAuth callback"); - const credentials = await oAuthCallback(e.data.code, e.data.state); - console.debug("OAuth callback processed successfully"); - - // Check if the credential's scopes match the required scopes + // Check if the credential's scopes match the required scopes (skip for MCP) + if (!isMCP) { const requiredScopes = schema.credentials_scopes; if (requiredScopes && requiredScopes.length > 0) { - const grantedScopes = new Set(credentials.scopes || []); + const grantedScopes = new Set(credentialResult.scopes || []); const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf( grantedScopes, ); if (!hasAllRequiredScopes) { - console.error( - `Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`, - requiredScopes, - "Granted:", - credentials.scopes, - ); setOAuthError( "Connection failed: the granted permissions don't match what's required. " + "Please contact the application administrator.", @@ -230,38 +239,28 @@ export function useCredentialsInput({ return; } } + } - onSelectCredential({ - id: credentials.id, - type: "oauth2", - title: credentials.title, - provider, - }); - } catch (error) { - console.error("Error in OAuth callback:", error); + onSelectCredential({ + id: credentialResult.id, + type: "oauth2", + title: credentialResult.title, + provider, + }); + } catch (error) { + if (error instanceof Error && error.message === "OAuth flow timed out") { + setOAuthError("OAuth flow timed out"); + } else { setOAuthError( - `Error in OAuth callback: ${ + `OAuth error: ${ error instanceof Error ? error.message : String(error) }`, ); - } finally { - console.debug("Finalizing OAuth flow"); - setOAuth2FlowInProgress(false); - controller.abort("success"); } - }; - - console.debug("Adding message event listener"); - window.addEventListener("message", handleMessage, { - signal: controller.signal, - }); - - setTimeout(() => { - console.debug("OAuth flow timed out"); - controller.abort("timeout"); + } finally { setOAuth2FlowInProgress(false); - setOAuthError("OAuth flow timed out"); - }, OAUTH_TIMEOUT_MS); + oauthAbortRef.current = null; + } } function handleActionButtonClick() { diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx index 97478e9eaf..a8b3514d41 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx @@ -80,7 +80,7 @@ export default function WrapIfAdditionalTemplate( uiSchema={uiSchema} /> {!isHandleConnected && ( -
+
 void; + /** The AbortController signal */ + signal: AbortSignal; +}; + +/** + * Opens an OAuth popup and sets up listeners for the callback result. + * + * Opens a blank popup synchronously (to avoid popup blockers), then navigates + * it to the login URL. Returns a promise that resolves with the OAuth code/state. + * + * @param loginUrl - The OAuth authorization URL to navigate to + * @param options - Configuration for message handling + * @returns Object with `promise` (resolves with the OAuth result) and `cleanup` (an `abort` function plus `signal` for cancelling the flow) + */ +export function openOAuthPopup( + loginUrl: string, + options: OAuthPopupOptions, +): { promise: Promise; cleanup: Cleanup } { + const { + stateToken, + useCrossOriginListeners = false, + broadcastChannelName = "mcp_oauth", + localStorageKey = "mcp_oauth_result", + acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"], + timeout = DEFAULT_TIMEOUT_MS, + } = options; + + const controller = new AbortController(); + + // Open popup synchronously (before any async work) to avoid browser popup blockers + const width = 500; + const height = 700; + const left = window.screenX + (window.outerWidth - width) / 2; + const top = window.screenY + (window.outerHeight - height) / 2; + const popup = window.open( + "about:blank", + "_blank", + `width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`, + ); + + if (popup && !popup.closed) { + popup.location.href = loginUrl; + } else { + // Popup was blocked — open in new tab as fallback + window.open(loginUrl, "_blank"); + } + + // Close popup on abort + controller.signal.addEventListener("abort", () => { + if (popup && !popup.closed) popup.close(); + }); + + // Clear any stale localStorage entry + if (useCrossOriginListeners) { + try { + localStorage.removeItem(localStorageKey); + } catch {} + } + + const promise = new Promise((resolve, reject) => { + let handled = false; + + const handleResult = (data: any) => { + if (handled) return; // Prevent double-handling + + // Validate message type + const messageType = data?.message_type ?? data?.type; + if (!messageType || !acceptMessageTypes.includes(messageType)) return; + + // Validate state token + if (data.state !== stateToken) { + // State mismatch — this message is for a different listener. Ignore silently.
+ return; + } + + handled = true; + + if (!data.success) { + reject(new Error(data.message || "OAuth authentication failed")); + } else { + resolve({ code: data.code, state: data.state }); + } + + controller.abort("completed"); + }; + + // Listener: postMessage (works for same-origin popups) + window.addEventListener( + "message", + (event: MessageEvent) => { + if (typeof event.data === "object") { + handleResult(event.data); + } + }, + { signal: controller.signal }, + ); + + // Cross-origin listeners for MCP OAuth + if (useCrossOriginListeners) { + // Listener: BroadcastChannel (works across tabs/popups without opener) + try { + const bc = new BroadcastChannel(broadcastChannelName); + bc.onmessage = (event) => handleResult(event.data); + controller.signal.addEventListener("abort", () => bc.close()); + } catch {} + + // Listener: localStorage polling (most reliable cross-tab fallback) + const pollInterval = setInterval(() => { + try { + const stored = localStorage.getItem(localStorageKey); + if (stored) { + const data = JSON.parse(stored); + localStorage.removeItem(localStorageKey); + handleResult(data); + } + } catch {} + }, 500); + controller.signal.addEventListener("abort", () => + clearInterval(pollInterval), + ); + } + + // Timeout + const timeoutId = setTimeout(() => { + if (!handled) { + handled = true; + reject(new Error("OAuth flow timed out")); + controller.abort("timeout"); + } + }, timeout); + controller.signal.addEventListener("abort", () => clearTimeout(timeoutId)); + }); + + return { + promise, + cleanup: { + abort: (reason?: string) => controller.abort(reason || "canceled"), + signal: controller.signal, + }, + }; +} diff --git a/autogpt_platform/frontend/src/middleware.ts b/autogpt_platform/frontend/src/middleware.ts index af1c823295..8cec8a2645 100644 --- a/autogpt_platform/frontend/src/middleware.ts +++ b/autogpt_platform/frontend/src/middleware.ts @@ -18,6 +18,6 @@ export const config = { * Note: /auth/authorize and /auth/integrations/* ARE protected and need * middleware to run for authentication checks. */ - "/((?!_next/static|_next/image|favicon.ico|auth/callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", + "/((?!_next/static|_next/image|favicon.ico|auth/callback|auth/integrations/mcp_callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", ], }; diff --git a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx index e47cc65e13..a426d8f667 100644 --- a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx +++ b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx @@ -8,6 +8,7 @@ import { HostScopedCredentials, UserPasswordCredentials, } from "@/lib/autogpt-server-api"; +import { postV2ExchangeOauthCodeForMcpTokens } from "@/app/api/__generated__/endpoints/mcp/mcp"; import { useBackendAPI } from "@/lib/autogpt-server-api/context"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { toDisplayName } from "@/providers/agent-credentials/helper"; @@ -38,6 +39,11 @@ export type CredentialsProviderData = { code: string, state_token: string, ) => Promise; + /** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. 
*/ + mcpOAuthCallback: ( + code: string, + state_token: string, + ) => Promise; createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => Promise; @@ -120,6 +126,35 @@ export default function CredentialsProvider({ [api, addCredentials, onFailToast], ); + /** Exchanges an MCP OAuth code for tokens and adds the result to the internal credentials store. */ + const mcpOAuthCallback = useCallback( + async ( + code: string, + state_token: string, + ): Promise => { + try { + const response = await postV2ExchangeOauthCodeForMcpTokens({ + code, + state_token, + }); + if (response.status !== 200) throw response.data; + const credsMeta: CredentialsMetaResponse = { + ...response.data, + title: response.data.title ?? undefined, + scopes: response.data.scopes ?? undefined, + username: response.data.username ?? undefined, + host: response.data.host ?? undefined, + }; + addCredentials("mcp", credsMeta); + return credsMeta; + } catch (error) { + onFailToast("complete MCP OAuth authentication")(error); + throw error; + } + }, + [addCredentials, onFailToast], + ); + /** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */ const createAPIKeyCredentials = useCallback( async ( @@ -258,6 +293,7 @@ export default function CredentialsProvider({ isSystemProvider: systemProviders.has(provider), oAuthCallback: (code: string, state_token: string) => oAuthCallback(provider, code, state_token), + mcpOAuthCallback, createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => createAPIKeyCredentials(provider, credentials), @@ -286,6 +322,7 @@ export default function CredentialsProvider({ createHostScopedCredentials, deleteCredentials, oAuthCallback, + mcpOAuthCallback, onFailToast, ]); diff --git a/autogpt_platform/frontend/src/tests/pages/build.page.ts b/autogpt_platform/frontend/src/tests/pages/build.page.ts index 9370288f8e..3bb9552b82 100644 --- a/autogpt_platform/frontend/src/tests/pages/build.page.ts +++ b/autogpt_platform/frontend/src/tests/pages/build.page.ts @@ -528,6 +528,9 @@ export class BuildPage extends BasePage { async getBlocksToSkip(): Promise { return [ (await this.getGithubTriggerBlockDetails()).map((b) => b.id), + // MCP Tool block requires an interactive dialog (server URL + OAuth) before + // it can be placed, so it can't be tested via the standard "add block" flow. 
+ "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", ].flat(); } diff --git a/docs/integrations/README.md b/docs/integrations/README.md index a471ef3533..00d4b0c73a 100644 --- a/docs/integrations/README.md +++ b/docs/integrations/README.md @@ -56,12 +56,16 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path | | [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value | | [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list | +| [Flatten List](block-integrations/basic.md#flatten-list) | Flattens a nested list structure into a single flat list | | [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering | | [Get Latest Memory](block-integrations/basic.md#get-latest-memory) | Retrieve the latest memory from Mem0 with optional key filtering | | [Get List Item](block-integrations/basic.md#get-list-item) | Returns the element at the given index | | [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store | | [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API | | [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution for human review | +| [Interleave Lists](block-integrations/basic.md#interleave-lists) | Interleaves elements from multiple lists in round-robin fashion, alternating between sources | +| [List Difference](block-integrations/basic.md#list-difference) | Computes the difference between two lists | +| [List Intersection](block-integrations/basic.md#list-intersection) | Computes the intersection of two lists, returning only elements present in both | | [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty | | [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library | | [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes | @@ -84,6 +88,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Store Value](block-integrations/basic.md#store-value) | A basic block that stores and forwards a value throughout workflows, allowing it to be reused without changes across multiple blocks | | [Universal Type Converter](block-integrations/basic.md#universal-type-converter) | This block is used to convert a value to a universal type | | [XML Parser](block-integrations/basic.md#xml-parser) | Parses XML using gravitasml to tokenize and coverts it to dict | +| [Zip Lists](block-integrations/basic.md#zip-lists) | Zips multiple lists together into a list of grouped elements | ## Data Processing @@ -467,6 +472,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request | | [Github Update File](block-integrations/github/repo.md#github-update-file) | This 
block updates an existing file in a GitHub repository | | [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block | +| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools | | [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped | ## Media Generation diff --git a/docs/integrations/SUMMARY.md b/docs/integrations/SUMMARY.md index f481ae2e0a..3ad4bf2c6d 100644 --- a/docs/integrations/SUMMARY.md +++ b/docs/integrations/SUMMARY.md @@ -84,6 +84,7 @@ * [Linear Projects](block-integrations/linear/projects.md) * [LLM](block-integrations/llm.md) * [Logic](block-integrations/logic.md) +* [MCP Block](block-integrations/mcp/block.md) * [Misc](block-integrations/misc.md) * [Notion Create Page](block-integrations/notion/create_page.md) * [Notion Read Database](block-integrations/notion/read_database.md) diff --git a/docs/integrations/block-integrations/basic.md b/docs/integrations/block-integrations/basic.md index 08def38ede..e032690edc 100644 --- a/docs/integrations/block-integrations/basic.md +++ b/docs/integrations/block-integrations/basic.md @@ -637,7 +637,7 @@ This enables extensibility by allowing custom blocks to be added without modifyi ## Concatenate Lists ### What it is -Concatenates multiple lists into a single list. All elements from all input lists are combined in order. +Concatenates multiple lists into a single list. All elements from all input lists are combined in order. Supports optional deduplication and None removal. ### How it works @@ -651,6 +651,8 @@ The block includes validation to ensure each item is actually a list. If a non-l | Input | Description | Type | Required | |-------|-------------|------|----------| | lists | A list of lists to concatenate together. All lists will be combined in order into a single list. | List[List[Any]] | Yes | +| deduplicate | If True, remove duplicate elements from the concatenated result while preserving order. | bool | No | +| remove_none | If True, remove None values from the concatenated result. | bool | No | ### Outputs @@ -658,6 +660,7 @@ The block includes validation to ensure each item is actually a list. If a non-l |--------|-------------|------| | error | Error message if concatenation failed due to invalid input types. | str | | concatenated_list | The concatenated list containing all elements from all input lists in order. | List[Any] | +| length | The total number of elements in the concatenated list. | int | ### Possible use case @@ -820,6 +823,45 @@ This enables conditional logic based on list membership and helps locate items f --- +## Flatten List + +### What it is +Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth. + +### How it works + +This block recursively traverses a nested list and extracts all leaf elements into a single flat list. You can control how deep the flattening goes with the max_depth parameter: set it to -1 to flatten completely, or to a positive integer to flatten only that many levels. + +The block also reports the original nesting depth of the input, which is useful for understanding the structure of data coming from sources with varying levels of nesting.
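To make the max_depth semantics concrete, here is a minimal Python sketch of the behavior described above. It is an illustration only, not the block's actual implementation: the function name, signature, and exact depth-budget handling are assumptions.

```python
def flatten(nested_list, max_depth=-1):
    """Flatten a nested list. max_depth=-1 flattens completely;
    max_depth=1 unwraps only one level of nesting."""
    if max_depth == 0:
        return list(nested_list)  # Depth budget exhausted: keep items as-is
    flat = []
    for item in nested_list:
        if isinstance(item, list):
            # Recurse, spending one level of the budget (-1 never runs out)
            next_depth = max_depth - 1 if max_depth > 0 else -1
            flat.extend(flatten(item, next_depth))
        else:
            flat.append(item)  # Leaf element: emit directly
    return flat

assert flatten([1, [2, [3, [4]]]]) == [1, 2, 3, 4]
assert flatten([1, [2, [3, [4]]]], max_depth=1) == [1, 2, [3, [4]]]
```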
+ + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| nested_list | A potentially nested list to flatten into a single-level list. | List[Any] | Yes | +| max_depth | Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level. | int | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if flattening failed. | str | +| flattened_list | The flattened list with all nested elements extracted. | List[Any] | +| length | The number of elements in the flattened list. | int | +| original_depth | The maximum nesting depth of the original input list. | int | + +### Possible use case + +**Normalizing API Responses**: Flatten nested JSON arrays from different API endpoints into a uniform single-level list for consistent processing. + +**Aggregating Nested Results**: Combine results from recursive file searches or nested category trees into a flat list of items for display or export. + +**Data Pipeline Cleanup**: Simplify deeply nested data structures from multiple transformation steps into a clean flat list before final output. + + +--- + ## Get All Memories ### What it is @@ -1012,6 +1054,120 @@ This enables human oversight at critical points in automated workflows, ensuring --- +## Interleave Lists + +### What it is +Interleaves elements from multiple lists in round-robin fashion, alternating between sources. + +### How it works + +This block takes elements from each input list in round-robin order, picking one element from each list in turn. For example, given `[[1, 2, 3], ['a', 'b', 'c']]`, it produces `[1, 'a', 2, 'b', 3, 'c']`. + +When lists have different lengths, shorter lists stop contributing once exhausted, and remaining elements from longer lists continue to be added in order. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| lists | A list of lists to interleave. Elements will be taken in round-robin order. | List[List[Any]] | Yes | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if interleaving failed. | str | +| interleaved_list | The interleaved list with elements alternating from each input list. | List[Any] | +| length | The total number of elements in the interleaved list. | int | + +### Possible use case + +**Balanced Content Mixing**: Alternate between content from different sources (e.g., mixing promotional and organic posts) for a balanced feed. + +**Round-Robin Scheduling**: Distribute tasks evenly across workers or queues by interleaving items from separate task lists. + +**Multi-Language Output**: Weave together translated text segments with their original counterparts for side-by-side comparison. + + +--- + +## List Difference + +### What it is +Computes the difference between two lists. Returns elements in the first list not found in the second, or symmetric difference. + +### How it works + +This block compares two lists and returns elements from list_a that do not appear in list_b. It uses hash-based lookup for efficient comparison. When symmetric mode is enabled, it returns elements that are in either list but not in both. + +The order of elements from list_a is preserved in the output, and elements from list_b are appended when using symmetric difference. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| list_a | The primary list to check elements from. 
| List[Any] | Yes | +| list_b | The list to subtract. Elements found here will be removed from list_a. | List[Any] | Yes | +| symmetric | If True, compute symmetric difference (elements in either list but not both). | bool | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed. | str | +| difference | Elements from list_a not found in list_b (or symmetric difference if enabled). | List[Any] | +| length | The number of elements in the difference result. | int | + +### Possible use case + +**Change Detection**: Compare a current list of records against a previous snapshot to find newly added or removed items. + +**Exclusion Filtering**: Remove items from a list that appear in a blocklist or already-processed list to avoid duplicates. + +**Data Sync**: Identify which items exist in one system but not another to determine what needs to be synced. + + +--- + +## List Intersection + +### What it is +Computes the intersection of two lists, returning only elements present in both. + +### How it works + +This block finds elements that appear in both input lists by hashing elements from list_b for efficient lookup, then checking each element of list_a against that set. The output preserves the order from list_a and removes duplicates. + +This is useful for finding common items between two datasets without needing to manually iterate or compare. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| list_a | The first list to intersect. | List[Any] | Yes | +| list_b | The second list to intersect. | List[Any] | Yes | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed. | str | +| intersection | Elements present in both list_a and list_b. | List[Any] | +| length | The number of elements in the intersection. | int | + +### Possible use case + +**Finding Common Tags**: Identify shared tags or categories between two items for recommendation or grouping purposes. + +**Mutual Connections**: Find users or contacts that appear in both of two different lists, such as shared friends or overlapping team members. + +**Feature Comparison**: Determine which features or capabilities are supported by both of two systems or products. + + +--- + ## List Is Empty ### What it is @@ -1452,3 +1608,42 @@ This makes XML data accessible using standard dictionary operations, allowing yo --- + +## Zip Lists + +### What it is +Zips multiple lists together into a list of grouped elements. Supports padding to longest or truncating to shortest. + +### How it works + +This block pairs up corresponding elements from multiple input lists into sub-lists. For example, zipping `[[1, 2, 3], ['a', 'b', 'c']]` produces `[[1, 'a'], [2, 'b'], [3, 'c']]`. + +By default, the result is truncated to the length of the shortest input list. Enable pad_to_longest to instead pad shorter lists with a fill_value so no elements from longer lists are lost. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| lists | A list of lists to zip together. Corresponding elements will be grouped. | List[List[Any]] | Yes | +| pad_to_longest | If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest. | bool | No | +| fill_value | Value to use for padding when pad_to_longest is True. 
| Fill Value | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if zipping failed. | str | +| zipped_list | The zipped list of grouped elements. | List[List[Any]] | +| length | The number of groups in the zipped result. | int | + +### Possible use case + +**Creating Key-Value Pairs**: Combine a list of field names with a list of values to build structured records or dictionaries. + +**Parallel Data Alignment**: Pair up corresponding items from separate data sources (e.g., names and email addresses) for processing together. + +**Table Row Construction**: Group column data into rows by zipping each column's values together for CSV export or display. + + +--- diff --git a/docs/integrations/block-integrations/mcp/block.md b/docs/integrations/block-integrations/mcp/block.md new file mode 100644 index 0000000000..6858e42e94 --- /dev/null +++ b/docs/integrations/block-integrations/mcp/block.md @@ -0,0 +1,40 @@ +# MCP Block + +Blocks for connecting to and executing tools on MCP (Model Context Protocol) servers. + + +## MCP Tool + +### What it is +Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically. + +### How it works + +The block uses JSON-RPC 2.0 over HTTP to communicate with MCP servers. When configuring, it sends an `initialize` request followed by `tools/list` to discover available tools and their input schemas. On execution, it calls `tools/call` with the selected tool name and arguments, then extracts text, image, or resource content from the response. + +Authentication is handled via OAuth 2.0 when the server requires it. The block supports optional credentials — public servers work without authentication, while protected servers trigger a standard OAuth flow with PKCE. Tokens are automatically refreshed when they expire. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes | +| selected_tool | The MCP tool to execute | str | No | +| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the tool call failed | str | +| result | The result returned by the MCP tool | Result | + +### Possible use case + +- **Connecting to third-party APIs**: Use an MCP server like Sentry or Linear to query issues, create tickets, or manage projects without building custom integrations. +- **AI-powered tool execution**: Chain MCP tool calls with AI blocks to let agents dynamically discover and use external tools based on task requirements. +- **Data retrieval from knowledge bases**: Connect to MCP servers like DeepWiki to search documentation, retrieve code context, or query structured knowledge bases. + + +--- diff --git a/plans/SECRT-1950-claude-ci-optimizations.md b/plans/SECRT-1950-claude-ci-optimizations.md new file mode 100644 index 0000000000..15d1419b0e --- /dev/null +++ b/plans/SECRT-1950-claude-ci-optimizations.md @@ -0,0 +1,165 @@ +# Implementation Plan: SECRT-1950 - Apply E2E CI Optimizations to Claude Code Workflows + +## Ticket +[SECRT-1950](https://linear.app/autogpt/issue/SECRT-1950) + +## Summary +Apply Pwuts's CI performance optimizations from PR #12090 to Claude Code workflows.
+ +## Reference PR +https://github.com/Significant-Gravitas/AutoGPT/pull/12090 + +--- + +## Analysis + +### Current State (claude.yml) + +**pnpm caching (lines 104-118):** +```yaml +- name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + +- name: Enable corepack + run: corepack enable + +- name: Set pnpm store directory + run: | + pnpm config set store-dir ~/.pnpm-store + echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV + +- name: Cache frontend dependencies + uses: actions/cache@v5 + with: + path: ~/.pnpm-store + key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} + restore-keys: | + ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} + ${{ runner.os }}-pnpm- +``` + +**Docker setup (lines 134-165):** +- Uses `docker-buildx-action@v3` +- Has manual Docker image caching via `actions/cache` +- Runs `docker compose up` without buildx bake optimization + +### Pwuts's Optimizations (PR #12090) + +1. **Simplified pnpm caching** - Use `setup-node` built-in cache: +```yaml +- name: Enable corepack + run: corepack enable + +- name: Set up Node + uses: actions/setup-node@v6 + with: + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml +``` + +2. **Docker build caching via buildx bake**: +```yaml +- name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: network=host + +- name: Expose GHA cache to docker buildx CLI + uses: crazy-max/ghaction-github-runtime@v3 + +- name: Build Docker images (with cache) + run: | + pip install pyyaml + docker compose -f docker-compose.yml config > docker-compose.resolved.yml + python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ + --source docker-compose.resolved.yml \ + --cache-from "type=gha" \ + --cache-to "type=gha,mode=max" \ + ... + docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load +``` + +--- + +## Proposed Changes + +### 1. Update pnpm caching in `claude.yml` + +**Before:** +- Manual cache key generation +- Separate `actions/cache` step +- Manual pnpm store directory config + +**After:** +- Use `setup-node` built-in `cache: "pnpm"` option +- Remove manual cache step +- Keep `corepack enable` before `setup-node` + +### 2. Update Docker build in `claude.yml` + +**Before:** +- Manual Docker layer caching via `actions/cache` with `/tmp/.buildx-cache` +- Simple `docker compose build` + +**After:** +- Use `crazy-max/ghaction-github-runtime@v3` to expose GHA cache +- Use `docker-ci-fix-compose-build-cache.py` script +- Build with `docker buildx bake` + +### 3. Apply same changes to other Claude workflows + +- `claude-dependabot.yml` - Check if it has similar patterns +- `claude-ci-failure-auto-fix.yml` - Check if it has similar patterns +- `copilot-setup-steps.yml` - Reusable workflow, may be the source of truth + +--- + +## Files to Modify + +1. `.github/workflows/claude.yml` +2. `.github/workflows/claude-dependabot.yml` (if applicable) +3. `.github/workflows/claude-ci-failure-auto-fix.yml` (if applicable) + +## Dependencies + +- PR #12090 must be merged first (provides the `docker-ci-fix-compose-build-cache.py` script) +- Backend Dockerfile optimizations (already in PR #12090) + +--- + +## Test Plan + +1. Create PR with changes +2. Trigger Claude workflow manually or via `@claude` mention on a test issue +3. Compare CI runtime before/after +4. 
Verify Claude agent still works correctly (can checkout, build, run tests) + +--- + +## Risk Assessment + +**Low risk:** +- These are CI infrastructure changes, not code changes +- If caching fails, builds fall back to uncached (slower but works) +- Changes mirror proven patterns from PR #12090 + +--- + +## Questions for Reviewer + +1. Should we wait for PR #12090 to merge before creating this PR? +2. Does `copilot-setup-steps.yml` need updating, or is it a separate concern? +3. Any concerns about cache key collisions between frontend E2E and Claude workflows? + +--- + +## Verified + +- ✅ **`claude-dependabot.yml`**: Has same pnpm caching pattern as `claude.yml` (manual `actions/cache`) — NEEDS UPDATE +- ✅ **`claude-ci-failure-auto-fix.yml`**: Simple workflow with no pnpm or Docker caching — NO CHANGES NEEDED +- ✅ **Script path**: `docker-ci-fix-compose-build-cache.py` will be at `.github/workflows/scripts/` after PR #12090 merges +- ✅ **Test seed caching**: NOT APPLICABLE — Claude workflows spin up a dev environment but don't run E2E tests with pre-seeded data. The seed caching in PR #12090 is specific to the frontend E2E test suite which needs consistent test data. Claude just needs the services running.