diff --git a/.github/scripts/detect_overlaps.py b/.github/scripts/detect_overlaps.py new file mode 100644 index 0000000000..1f9f4be7cf --- /dev/null +++ b/.github/scripts/detect_overlaps.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python3 +""" +PR Overlap Detection Tool + +Detects potential merge conflicts between a given PR and other open PRs +by checking for file overlap, line overlap, and actual merge conflicts. +""" + +import json +import os +import re +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from typing import Optional + + +# ============================================================================= +# MAIN ENTRY POINT +# ============================================================================= + +def main(): + """Main entry point for PR overlap detection.""" + import argparse + + parser = argparse.ArgumentParser(description="Detect PR overlaps and potential merge conflicts") + parser.add_argument("pr_number", type=int, help="PR number to check") + parser.add_argument("--base", default=None, help="Base branch (default: auto-detect from PR)") + parser.add_argument("--skip-merge-test", action="store_true", help="Skip actual merge conflict testing") + parser.add_argument("--discord-webhook", default=os.environ.get("DISCORD_WEBHOOK_URL"), help="Discord webhook URL for notifications") + parser.add_argument("--dry-run", action="store_true", help="Don't post comments, just print") + + args = parser.parse_args() + + owner, repo = get_repo_info() + print(f"Checking PR #{args.pr_number} in {owner}/{repo}") + + # Get current PR info + current_pr = fetch_pr_details(args.pr_number) + base_branch = args.base or current_pr.base_ref + + print(f"PR #{current_pr.number}: {current_pr.title}") + print(f"Base branch: {base_branch}") + print(f"Files changed: {len(current_pr.files)}") + + # Find overlapping PRs + overlaps, all_changes = find_overlapping_prs( + owner, repo, base_branch, current_pr, args.pr_number, args.skip_merge_test + ) + + if not overlaps: + print("No overlaps detected!") + return + + # Generate and post report + comment = format_comment(overlaps, args.pr_number, current_pr.changed_ranges, all_changes) + + if args.dry_run: + print("\n" + "="*60) + print("COMMENT PREVIEW:") + print("="*60) + print(comment) + else: + if comment: + post_or_update_comment(args.pr_number, comment) + print("Posted comment to PR") + + if args.discord_webhook: + send_discord_notification(args.discord_webhook, current_pr, overlaps) + + # Report results and exit + report_results(overlaps) + + +# ============================================================================= +# HIGH-LEVEL WORKFLOW FUNCTIONS +# ============================================================================= + +def fetch_pr_details(pr_number: int) -> "PullRequest": + """Fetch details for a specific PR including its diff.""" + result = run_gh(["pr", "view", str(pr_number), "--json", "number,title,url,author,headRefName,baseRefName,files"]) + data = json.loads(result.stdout) + + pr = PullRequest( + number=data["number"], + title=data["title"], + author=data["author"]["login"] if data.get("author") else "unknown", + url=data["url"], + head_ref=data["headRefName"], + base_ref=data["baseRefName"], + files=[f["path"] for f in data["files"]], + changed_ranges={} + ) + + # Get detailed diff + diff = get_pr_diff(pr_number) + pr.changed_ranges = parse_diff_ranges(diff) + + return pr + + +def find_overlapping_prs( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + 
current_pr_number: int, + skip_merge_test: bool +) -> tuple[list["Overlap"], dict[int, dict[str, "ChangedFile"]]]: + """Find all PRs that overlap with the current PR.""" + # Query other open PRs + all_prs = query_open_prs(owner, repo, base_branch) + other_prs = [p for p in all_prs if p["number"] != current_pr_number] + + print(f"Found {len(other_prs)} other open PRs targeting {base_branch}") + + # Find file overlaps (excluding ignored files, filtering by age) + candidates = find_file_overlap_candidates(current_pr.files, other_prs) + + print(f"Found {len(candidates)} PRs with file overlap (excluding ignored files)") + + if not candidates: + return [], {} + + # First pass: analyze line overlaps (no merge testing yet) + overlaps = [] + all_changes = {} + prs_needing_merge_test = [] + + for pr_data, shared_files in candidates: + overlap, pr_changes = analyze_pr_overlap( + owner, repo, base_branch, current_pr, pr_data, shared_files, + skip_merge_test=True # Always skip in first pass + ) + if overlap: + overlaps.append(overlap) + all_changes[pr_data["number"]] = pr_changes + # Track PRs that need merge testing + if overlap.line_overlaps and not skip_merge_test: + prs_needing_merge_test.append(overlap) + + # Second pass: batch merge testing with shared clone + if prs_needing_merge_test: + run_batch_merge_tests(owner, repo, base_branch, current_pr, prs_needing_merge_test) + + return overlaps, all_changes + + +def run_batch_merge_tests( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + overlaps: list["Overlap"] +): + """Run merge tests for multiple PRs using a shared clone.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone once + if not clone_repo(owner, repo, base_branch, tmpdir): + return + + configure_git(tmpdir) + + # Fetch current PR branch once + result = run_git(["fetch", "origin", f"pull/{current_pr.number}/head:pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch current PR #{current_pr.number}", file=sys.stderr) + return + + for overlap in overlaps: + other_pr = overlap.pr_b if overlap.pr_a.number == current_pr.number else overlap.pr_a + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + + # Clean up any in-progress merge from previous iteration + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + # Reset to base branch + run_git(["checkout", base_branch], cwd=tmpdir, check=False) + run_git(["reset", "--hard", f"origin/{base_branch}"], cwd=tmpdir, check=False) + run_git(["clean", "-fdx"], cwd=tmpdir, check=False) + + # Fetch the other PR branch + result = run_git(["fetch", "origin", f"pull/{other_pr.number}/head:pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{other_pr.number}: {result.stderr.strip()}", file=sys.stderr) + continue + + # Try merging current PR first + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Current PR conflicts with base + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'pr_a_conflicts_base' + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + continue + + # Commit and try merging other PR + run_git(["commit", "-m", f"Merge PR #{current_pr.number}"], cwd=tmpdir, check=False) + + result = 
run_git(["merge", "--no-commit", "--no-ff", f"pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Conflict between PRs + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'conflict' + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + +def analyze_pr_overlap( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + other_pr_data: dict, + shared_files: list[str], + skip_merge_test: bool +) -> tuple[Optional["Overlap"], dict[str, "ChangedFile"]]: + """Analyze overlap between current PR and another PR.""" + # Filter out ignored files + non_ignored_shared = [f for f in shared_files if not should_ignore_file(f)] + if not non_ignored_shared: + return None, {} + + other_pr = PullRequest( + number=other_pr_data["number"], + title=other_pr_data["title"], + author=other_pr_data["author"], + url=other_pr_data["url"], + head_ref=other_pr_data["head_ref"], + base_ref=other_pr_data["base_ref"], + files=other_pr_data["files"], + changed_ranges={}, + updated_at=other_pr_data.get("updated_at") + ) + + # Get diff for other PR + other_diff = get_pr_diff(other_pr.number) + other_pr.changed_ranges = parse_diff_ranges(other_diff) + + # Check line overlaps + line_overlaps = find_line_overlaps( + current_pr.changed_ranges, + other_pr.changed_ranges, + shared_files + ) + + overlap = Overlap( + pr_a=current_pr, + pr_b=other_pr, + overlapping_files=non_ignored_shared, + line_overlaps=line_overlaps + ) + + # Test for actual merge conflicts if we have line overlaps + if line_overlaps and not skip_merge_test: + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + has_conflict, conflict_files, conflict_details, error_type = test_merge_conflict( + owner, repo, base_branch, current_pr, other_pr + ) + overlap.has_merge_conflict = has_conflict + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = error_type + + return overlap, other_pr.changed_ranges + + +def find_file_overlap_candidates( + current_files: list[str], + other_prs: list[dict], + max_age_days: int = 14 +) -> list[tuple[dict, list[str]]]: + """Find PRs that share files with the current PR.""" + from datetime import datetime, timezone, timedelta + + current_files_set = set(f for f in current_files if not should_ignore_file(f)) + candidates = [] + cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days) + + for pr_data in other_prs: + # Filter out PRs older than max_age_days + updated_at = pr_data.get("updated_at") + if updated_at: + try: + pr_date = datetime.fromisoformat(updated_at.replace('Z', '+00:00')) + if pr_date < cutoff_date: + continue # Skip old PRs + except Exception as e: + # If we can't parse date, include the PR (safe fallback) + print(f"Warning: Could not parse date for PR: {e}", file=sys.stderr) + + other_files = set(f for f in pr_data["files"] if not should_ignore_file(f)) + shared = current_files_set & other_files + + if shared: + candidates.append((pr_data, list(shared))) + + return candidates + + +def report_results(overlaps: list["Overlap"]): + """Report results (informational only, always exits 0).""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if conflicts: + print(f"\n⚠️ Found {len(conflicts)} merge conflict(s)") + + line_overlap_count = len([o for o in overlaps if o.line_overlaps]) + if 
line_overlap_count: + print(f"\n⚠️ Found {line_overlap_count} PR(s) with line overlap") + + print("\n✅ Done") + # Always exit 0 - this check is informational, not a merge blocker + + +# ============================================================================= +# COMMENT FORMATTING +# ============================================================================= + +def format_comment( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict[str, "ChangedFile"], + all_changes: dict[int, dict[str, "ChangedFile"]] +) -> str: + """Format the overlap report as a PR comment.""" + if not overlaps: + return "" + + lines = ["## 🔍 PR Overlap Detection"] + lines.append("") + lines.append("This check compares your PR against all other open PRs targeting the same branch to detect potential merge conflicts early.") + lines.append("") + + # Check if current PR conflicts with base branch + format_base_conflicts(overlaps, lines) + + # Classify and sort overlaps + classified = classify_all_overlaps(overlaps, current_pr, changes_current, all_changes) + + # Group by risk + conflicts = [(o, r) for o, r in classified if r == 'conflict'] + medium_risk = [(o, r) for o, r in classified if r == 'medium'] + low_risk = [(o, r) for o, r in classified if r == 'low'] + + # Format each section + format_conflicts_section(conflicts, current_pr, lines) + format_medium_risk_section(medium_risk, current_pr, changes_current, all_changes, lines) + format_low_risk_section(low_risk, current_pr, lines) + + # Summary + total = len(overlaps) + lines.append(f"\n**Summary:** {len(conflicts)} conflict(s), {len(medium_risk)} medium risk, {len(low_risk)} low risk (out of {total} PRs with file overlap)") + lines.append("\n---\n*Auto-generated on push. Ignores: `openapi.json`, lock files.*") + + return "\n".join(lines) + + +def format_base_conflicts(overlaps: list["Overlap"], lines: list[str]): + """Format base branch conflicts section.""" + base_conflicts = [o for o in overlaps if o.conflict_type == 'pr_a_conflicts_base'] + if base_conflicts: + lines.append("### ⚠️ This PR has conflicts with the base branch\n") + lines.append("Conflicts will need to be resolved before merging:\n") + first = base_conflicts[0] + for f in first.conflict_files[:10]: + lines.append(f"- `{f}`") + if len(first.conflict_files) > 10: + lines.append(f"- ... and {len(first.conflict_files) - 10} more files") + lines.append("\n") + + +def format_conflicts_section(conflicts: list[tuple], current_pr: int, lines: list[str]): + """Format the merge conflicts section.""" + pr_conflicts = [(o, r) for o, r in conflicts if o.conflict_type != 'pr_a_conflicts_base'] + + if not pr_conflicts: + return + + lines.append("### 🔴 Merge Conflicts Detected") + lines.append("") + lines.append("The following PRs have been tested and **will have merge conflicts** if merged after this PR. 
Consider coordinating with the authors.") + lines.append("") + + for o, _ in pr_conflicts: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + format_pr_entry(other, lines) + format_conflict_details(o, lines) + lines.append("") + + +def format_medium_risk_section( + medium_risk: list[tuple], + current_pr: int, + changes_current: dict, + all_changes: dict, + lines: list[str] +): + """Format the medium risk section.""" + if not medium_risk: + return + + lines.append("### 🟡 Medium Risk — Some Line Overlap\n") + lines.append("These PRs have some overlapping changes:\n") + + for o, _ in medium_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other.number, {}) + format_pr_entry(other, lines) + + # Note if rename is involved + for file_path in o.overlapping_files: + file_a = changes_current.get(file_path) + file_b = other_changes.get(file_path) + if (file_a and file_a.is_rename) or (file_b and file_b.is_rename): + lines.append(f" - ⚠️ `{file_path}` is being renamed/moved") + break + + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") + + +def format_low_risk_section(low_risk: list[tuple], current_pr: int, lines: list[str]): + """Format the low risk section.""" + if not low_risk: + return + + lines.append("### 🟢 Low Risk — File Overlap Only\n") + lines.append("
These PRs touch the same files but different sections (click to expand)\n") + + for o, _ in low_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + format_pr_entry(other, lines) + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") # Add blank line between entries + + lines.append("
\n") + + +def format_pr_entry(pr: "PullRequest", lines: list[str]): + """Format a single PR entry line.""" + updated = format_relative_time(pr.updated_at) + updated_str = f" · updated {updated}" if updated else "" + # Just use #number - GitHub auto-renders it with title + lines.append(f"- #{pr.number} ({pr.author}{updated_str})") + + +def format_conflict_details(overlap: "Overlap", lines: list[str]): + """Format conflict details for a PR.""" + if overlap.conflict_details: + all_paths = [d.path for d in overlap.conflict_details] + common_prefix = find_common_prefix(all_paths) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for detail in overlap.conflict_details: + display_path = detail.path[len(common_prefix):] if common_prefix else detail.path + size_str = format_conflict_size(detail) + lines.append(f" - `{display_path}`{size_str}") + elif overlap.conflict_files: + common_prefix = find_common_prefix(overlap.conflict_files) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for f in overlap.conflict_files: + display_path = f[len(common_prefix):] if common_prefix else f + lines.append(f" - `{display_path}`") + + +def format_conflict_size(detail: "ConflictInfo") -> str: + """Format conflict size string for a file.""" + if detail.conflict_count > 0: + return f" ({detail.conflict_count} conflict{'s' if detail.conflict_count > 1 else ''}, ~{detail.conflict_lines} lines)" + elif detail.conflict_type != 'content': + type_labels = { + 'both_added': 'added in both', + 'both_deleted': 'deleted in both', + 'deleted_by_us': 'deleted here, modified there', + 'deleted_by_them': 'modified here, deleted there', + 'added_by_us': 'added here', + 'added_by_them': 'added there', + } + label = type_labels.get(detail.conflict_type, detail.conflict_type) + return f" ({label})" + return "" + + +def format_line_overlaps(line_overlaps: dict[str, list[tuple]], lines: list[str]): + """Format line overlap details.""" + all_paths = list(line_overlaps.keys()) + common_prefix = find_common_prefix(all_paths) if len(all_paths) > 1 else "" + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for file_path, ranges in line_overlaps.items(): + display_path = file_path[len(common_prefix):] if common_prefix else file_path + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + indent = " " if common_prefix else " " + lines.append(f"{indent}- `{display_path}`: {', '.join(range_strs)}") + + +# ============================================================================= +# OVERLAP ANALYSIS +# ============================================================================= + +def classify_all_overlaps( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict, + all_changes: dict +) -> list[tuple["Overlap", str]]: + """Classify all overlaps by risk level and sort them.""" + classified = [] + for o in overlaps: + other_pr = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other_pr.number, {}) + risk = classify_overlap_risk(o, changes_current, other_changes) + classified.append((o, risk)) + + def sort_key(item): + o, risk = item + risk_order = {'conflict': 0, 'medium': 1, 'low': 2} + # For conflicts, also sort by total conflict lines (descending) + conflict_lines = sum(d.conflict_lines for d in o.conflict_details) if o.conflict_details else 0 + return (risk_order.get(risk, 99), -conflict_lines) + + classified.sort(key=sort_key) + + return classified + + +def classify_overlap_risk( + overlap: "Overlap", + changes_a: 
dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"] +) -> str: + """Classify the risk level of an overlap.""" + if overlap.has_merge_conflict: + return 'conflict' + + has_rename = any( + (changes_a.get(f) and changes_a[f].is_rename) or + (changes_b.get(f) and changes_b[f].is_rename) + for f in overlap.overlapping_files + ) + + if overlap.line_overlaps: + total_overlap_lines = sum( + end - start + 1 + for ranges in overlap.line_overlaps.values() + for start, end in ranges + ) + + # Medium risk: >20 lines overlap or file rename + if total_overlap_lines > 20 or has_rename: + return 'medium' + else: + return 'low' + + if has_rename: + return 'medium' + + return 'low' + + +def find_line_overlaps( + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"], + shared_files: list[str] +) -> dict[str, list[tuple[int, int]]]: + """Find overlapping line ranges in shared files.""" + overlaps = {} + + for file_path in shared_files: + if should_ignore_file(file_path): + continue + + file_a = changes_a.get(file_path) + file_b = changes_b.get(file_path) + + if not file_a or not file_b: + continue + + # Skip pure renames + if file_a.is_rename and not file_a.additions and not file_a.deletions: + continue + if file_b.is_rename and not file_b.additions and not file_b.deletions: + continue + + # Note: This mixes old-file (deletions) and new-file (additions) line numbers, + # which can cause false positives when PRs insert/remove many lines. + # Acceptable for v1 since the real merge test is the authoritative check. + file_overlaps = find_range_overlaps( + file_a.additions + file_a.deletions, + file_b.additions + file_b.deletions + ) + + if file_overlaps: + overlaps[file_path] = merge_ranges(file_overlaps) + + return overlaps + + +def find_range_overlaps( + ranges_a: list[tuple[int, int]], + ranges_b: list[tuple[int, int]] +) -> list[tuple[int, int]]: + """Find overlapping regions between two sets of ranges.""" + overlaps = [] + for range_a in ranges_a: + for range_b in ranges_b: + if ranges_overlap(range_a, range_b): + overlap_start = max(range_a[0], range_b[0]) + overlap_end = min(range_a[1], range_b[1]) + overlaps.append((overlap_start, overlap_end)) + return overlaps + + +def ranges_overlap(range_a: tuple[int, int], range_b: tuple[int, int]) -> bool: + """Check if two line ranges overlap.""" + return range_a[0] <= range_b[1] and range_b[0] <= range_a[1] + + +def merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]: + """Merge overlapping line ranges.""" + if not ranges: + return [] + + sorted_ranges = sorted(ranges, key=lambda x: x[0]) + merged = [sorted_ranges[0]] + + for current in sorted_ranges[1:]: + last = merged[-1] + if current[0] <= last[1] + 1: + merged[-1] = (last[0], max(last[1], current[1])) + else: + merged.append(current) + + return merged + + +# ============================================================================= +# MERGE CONFLICT TESTING +# ============================================================================= + +def test_merge_conflict( + owner: str, + repo: str, + base_branch: str, + pr_a: "PullRequest", + pr_b: "PullRequest" +) -> tuple[bool, list[str], list["ConflictInfo"], str]: + """Test if merging both PRs would cause a conflict.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone repo + if not clone_repo(owner, repo, base_branch, tmpdir): + return False, [], [], None + + configure_git(tmpdir) + if not fetch_pr_branches(tmpdir, pr_a.number, pr_b.number): + # Fetch failed for one or both PRs - can't test merge + 
return False, [], [], None + + # Try merging PR A first + conflict_result = try_merge_pr(tmpdir, pr_a.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'pr_a_conflicts_base' + + # Commit and try merging PR B + run_git(["commit", "-m", f"Merge PR #{pr_a.number}"], cwd=tmpdir, check=False) + + conflict_result = try_merge_pr(tmpdir, pr_b.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'conflict' + + return False, [], [], None + + +def clone_repo(owner: str, repo: str, branch: str, tmpdir: str) -> bool: + """Clone the repository.""" + clone_url = f"https://github.com/{owner}/{repo}.git" + result = run_git( + ["clone", "--depth=50", "--branch", branch, clone_url, tmpdir], + check=False + ) + if result.returncode != 0: + print(f"Failed to clone: {result.stderr}", file=sys.stderr) + return False + return True + + +def configure_git(tmpdir: str): + """Configure git for commits.""" + run_git(["config", "user.email", "github-actions[bot]@users.noreply.github.com"], cwd=tmpdir, check=False) + run_git(["config", "user.name", "github-actions[bot]"], cwd=tmpdir, check=False) + + +def fetch_pr_branches(tmpdir: str, pr_a: int, pr_b: int) -> bool: + """Fetch both PR branches. Returns False if any fetch fails.""" + success = True + for pr_num in (pr_a, pr_b): + result = run_git(["fetch", "origin", f"pull/{pr_num}/head:pr-{pr_num}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{pr_num}: {result.stderr.strip()}", file=sys.stderr) + success = False + return success + + +def try_merge_pr(tmpdir: str, pr_number: int) -> Optional[tuple[list[str], list["ConflictInfo"]]]: + """Try to merge a PR. Returns conflict info if conflicts, None if success.""" + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{pr_number}"], cwd=tmpdir, check=False) + + if result.returncode == 0: + return None + + # Conflict detected + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + return conflict_files, conflict_details + + +def extract_conflict_info(tmpdir: str, stderr: str) -> tuple[list[str], list["ConflictInfo"]]: + """Extract conflict information from git status.""" + status_result = run_git(["status", "--porcelain"], cwd=tmpdir, check=False) + + status_types = { + 'UU': 'content', + 'AA': 'both_added', + 'DD': 'both_deleted', + 'DU': 'deleted_by_us', + 'UD': 'deleted_by_them', + 'AU': 'added_by_us', + 'UA': 'added_by_them', + } + + conflict_files = [] + conflict_details = [] + + for line in status_result.stdout.split("\n"): + if len(line) >= 3 and line[0:2] in status_types: + status_code = line[0:2] + file_path = line[3:].strip() + conflict_files.append(file_path) + + info = analyze_conflict_markers(file_path, tmpdir) + info.conflict_type = status_types.get(status_code, 'unknown') + conflict_details.append(info) + + # Fallback to stderr parsing + if not conflict_files and stderr: + for line in stderr.split("\n"): + if "CONFLICT" in line and ":" in line: + parts = line.split(":") + if len(parts) > 1: + file_part = parts[-1].strip() + if file_part and not file_part.startswith("Merge"): + conflict_files.append(file_part) + conflict_details.append(ConflictInfo(path=file_part)) + + return conflict_files, conflict_details + + +def analyze_conflict_markers(file_path: str, cwd: str) -> "ConflictInfo": + """Analyze a conflicted file to count conflict regions and lines.""" + info = ConflictInfo(path=file_path) + + try: + 
full_path = os.path.join(cwd, file_path)
+        with open(full_path, 'r', errors='ignore') as f:
+            content = f.read()
+
+        in_conflict = False
+        current_conflict_lines = 0
+
+        for line in content.split('\n'):
+            if line.startswith('<<<<<<<'):
+                in_conflict = True
+                info.conflict_count += 1
+                current_conflict_lines = 1
+            elif line.startswith('>>>>>>>'):
+                in_conflict = False
+                current_conflict_lines += 1
+                info.conflict_lines += current_conflict_lines
+            elif in_conflict:
+                current_conflict_lines += 1
+    except Exception as e:
+        print(f"Warning: Could not analyze conflict markers in {file_path}: {e}", file=sys.stderr)
+
+    return info
+
+
+# =============================================================================
+# DIFF PARSING
+# =============================================================================
+
+def parse_diff_ranges(diff: str) -> dict[str, "ChangedFile"]:
+    """Parse a unified diff and extract changed line ranges per file."""
+    files = {}
+    current_file = None
+    pending_rename_from = None
+    is_rename = False
+
+    for line in diff.split("\n"):
+        # Reset per-file state on each new file header so hunks from a file
+        # without a "+++ b/" line (e.g. a deleted file) are not attributed
+        # to the previously parsed file
+        if line.startswith("diff --git "):
+            current_file = None
+            is_rename = False
+            pending_rename_from = None
+        elif line.startswith("rename from "):
+            pending_rename_from = line[12:]
+            is_rename = True
+        elif line.startswith("rename to "):
+            pass  # rename target is captured via "+++ b/" line
+        elif line.startswith("similarity index"):
+            is_rename = True
+        elif line.startswith("+++ b/"):
+            path = line[6:]
+            current_file = ChangedFile(
+                path=path,
+                additions=[],
+                deletions=[],
+                is_rename=is_rename,
+                old_path=pending_rename_from
+            )
+            files[path] = current_file
+            pending_rename_from = None
+            is_rename = False
+        elif line.startswith("--- /dev/null"):
+            is_rename = False
+            pending_rename_from = None
+        elif line.startswith("@@") and current_file:
+            parse_hunk_header(line, current_file)
+
+    return files
+
+
+def parse_hunk_header(line: str, current_file: "ChangedFile"):
+    """Parse a diff hunk header and add ranges to the file."""
+    match = re.match(r"@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@", line) + if match: + old_start = int(match.group(1)) + old_count = int(match.group(2) or 1) + new_start = int(match.group(3)) + new_count = int(match.group(4) or 1) + + if old_count > 0: + current_file.deletions.append((old_start, old_start + old_count - 1)) + if new_count > 0: + current_file.additions.append((new_start, new_start + new_count - 1)) + + +# ============================================================================= +# GITHUB API +# ============================================================================= + +def get_repo_info() -> tuple[str, str]: + """Get owner and repo name from environment or git.""" + if os.environ.get("GITHUB_REPOSITORY"): + owner, repo = os.environ["GITHUB_REPOSITORY"].split("/") + return owner, repo + + result = run_gh(["repo", "view", "--json", "owner,name"]) + data = json.loads(result.stdout) + return data["owner"]["login"], data["name"] + + +def query_open_prs(owner: str, repo: str, base_branch: str) -> list[dict]: + """Query all open PRs targeting the specified base branch.""" + prs = [] + cursor = None + + while True: + after_clause = f', after: "{cursor}"' if cursor else "" + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequests( + first: 100{after_clause}, + states: OPEN, + baseRefName: "{base_branch}", + orderBy: {{field: UPDATED_AT, direction: DESC}} + ) {{ + totalCount + edges {{ + node {{ + number + title + url + updatedAt + author {{ login }} + headRefName + baseRefName + files(first: 100) {{ + nodes {{ path }} + pageInfo {{ hasNextPage }} + }} + }} + }} + pageInfo {{ + endCursor + hasNextPage + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"]) + data = json.loads(result.stdout) + + if "errors" in data: + print(f"GraphQL errors: {data['errors']}", file=sys.stderr) + sys.exit(1) + + pr_data = data["data"]["repository"]["pullRequests"] + for edge in pr_data["edges"]: + node = edge["node"] + files_data = node["files"] + # Warn if PR has more than 100 files (API limit, we only fetch first 100) + if files_data.get("pageInfo", {}).get("hasNextPage"): + print(f"Warning: PR #{node['number']} has >100 files, overlap detection may be incomplete", file=sys.stderr) + prs.append({ + "number": node["number"], + "title": node["title"], + "url": node["url"], + "updated_at": node.get("updatedAt"), + "author": node["author"]["login"] if node["author"] else "unknown", + "head_ref": node["headRefName"], + "base_ref": node["baseRefName"], + "files": [f["path"] for f in files_data["nodes"]] + }) + + if not pr_data["pageInfo"]["hasNextPage"]: + break + cursor = pr_data["pageInfo"]["endCursor"] + + return prs + + +def get_pr_diff(pr_number: int) -> str: + """Get the diff for a PR.""" + result = run_gh(["pr", "diff", str(pr_number)]) + return result.stdout + + +def post_or_update_comment(pr_number: int, body: str): + """Post a new comment or update existing overlap detection comment.""" + if not body: + return + + marker = "## 🔍 PR Overlap Detection" + + # Find existing comment using GraphQL + owner, repo = get_repo_info() + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequest(number: {pr_number}) {{ + comments(first: 100) {{ + nodes {{ + id + body + author {{ login }} + }} + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"], check=False) + + existing_comment_id = None + if result.returncode == 0: + try: + data = json.loads(result.stdout) + comments = data.get("data", {}).get("repository", {}).get("pullRequest", 
{}).get("comments", {}).get("nodes", []) + for comment in comments: + if marker in comment.get("body", ""): + existing_comment_id = comment["id"] + break + except Exception as e: + print(f"Warning: Could not search for existing comment: {e}", file=sys.stderr) + + if existing_comment_id: + # Update existing comment using GraphQL mutation + # Use json.dumps for proper escaping of all special characters + escaped_body = json.dumps(body)[1:-1] # Strip outer quotes added by json.dumps + mutation = f''' + mutation {{ + updateIssueComment(input: {{id: "{existing_comment_id}", body: "{escaped_body}"}}) {{ + issueComment {{ id }} + }} + }} + ''' + result = run_gh(["api", "graphql", "-f", f"query={mutation}"], check=False) + if result.returncode == 0: + print(f"Updated existing overlap comment") + else: + # Fallback to posting new comment + print(f"Failed to update comment, posting new one: {result.stderr}", file=sys.stderr) + run_gh(["pr", "comment", str(pr_number), "--body", body]) + else: + # Post new comment + run_gh(["pr", "comment", str(pr_number), "--body", body]) + + +def send_discord_notification(webhook_url: str, pr: "PullRequest", overlaps: list["Overlap"]): + """Send a Discord notification about significant overlaps.""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if not conflicts: + return + + # Discord limits: max 25 fields, max 1024 chars per field value + fields = [] + for o in conflicts[:25]: + other = o.pr_b if o.pr_a.number == pr.number else o.pr_a + # Build value string with truncation to stay under 1024 chars + file_list = o.conflict_files[:3] + files_str = f"Files: `{'`, `'.join(file_list)}`" + if len(o.conflict_files) > 3: + files_str += f" (+{len(o.conflict_files) - 3} more)" + value = f"[{other.title[:100]}]({other.url})\n{files_str}" + # Truncate if still too long + if len(value) > 1024: + value = value[:1020] + "..." + fields.append({ + "name": f"Conflicts with #{other.number}", + "value": value, + "inline": False + }) + + embed = { + "title": f"⚠️ PR #{pr.number} has merge conflicts", + "description": f"[{pr.title}]({pr.url})", + "color": 0xFF0000, + "fields": fields + } + + if len(conflicts) > 25: + embed["footer"] = {"text": f"... 
and {len(conflicts) - 25} more conflicts"} + + try: + subprocess.run( + ["curl", "-X", "POST", "-H", "Content-Type: application/json", + "--max-time", "10", + "-d", json.dumps({"embeds": [embed]}), webhook_url], + capture_output=True, + timeout=15 + ) + except subprocess.TimeoutExpired: + print("Warning: Discord webhook timed out", file=sys.stderr) + + +# ============================================================================= +# UTILITIES +# ============================================================================= + +def run_gh(args: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Run a gh CLI command.""" + result = subprocess.run( + ["gh"] + args, + capture_output=True, + text=True, + check=False + ) + if check and result.returncode != 0: + print(f"Error running gh {' '.join(args)}: {result.stderr}", file=sys.stderr) + sys.exit(1) + return result + + +def run_git(args: list[str], cwd: str = None, check: bool = True) -> subprocess.CompletedProcess: + """Run a git command.""" + result = subprocess.run( + ["git"] + args, + capture_output=True, + text=True, + cwd=cwd, + check=False + ) + if check and result.returncode != 0: + print(f"Error running git {' '.join(args)}: {result.stderr}", file=sys.stderr) + return result + + +def should_ignore_file(path: str) -> bool: + """Check if a file should be ignored for overlap detection.""" + if path in IGNORE_FILES: + return True + basename = path.split("/")[-1] + return basename in IGNORE_FILES + + +def find_common_prefix(paths: list[str]) -> str: + """Find the common directory prefix of a list of file paths.""" + if not paths: + return "" + if len(paths) == 1: + parts = paths[0].rsplit('/', 1) + return parts[0] + '/' if len(parts) > 1 else "" + + split_paths = [p.split('/') for p in paths] + common = [] + for parts in zip(*split_paths): + if len(set(parts)) == 1: + common.append(parts[0]) + else: + break + + return '/'.join(common) + '/' if common else "" + + +def format_relative_time(iso_timestamp: str) -> str: + """Format an ISO timestamp as relative time.""" + if not iso_timestamp: + return "" + + from datetime import datetime, timezone + try: + dt = datetime.fromisoformat(iso_timestamp.replace('Z', '+00:00')) + now = datetime.now(timezone.utc) + diff = now - dt + + seconds = diff.total_seconds() + if seconds < 60: + return "just now" + elif seconds < 3600: + return f"{int(seconds / 60)}m ago" + elif seconds < 86400: + return f"{int(seconds / 3600)}h ago" + else: + return f"{int(seconds / 86400)}d ago" + except Exception as e: + print(f"Warning: Could not format relative time: {e}", file=sys.stderr) + return "" + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + +@dataclass +class ChangedFile: + """Represents a file changed in a PR.""" + path: str + additions: list[tuple[int, int]] + deletions: list[tuple[int, int]] + is_rename: bool = False + old_path: str = None + + +@dataclass +class PullRequest: + """Represents a pull request.""" + number: int + title: str + author: str + url: str + head_ref: str + base_ref: str + files: list[str] + changed_ranges: dict[str, ChangedFile] + updated_at: str = None + + +@dataclass +class ConflictInfo: + """Info about a single conflicting file.""" + path: str + conflict_count: int = 0 + conflict_lines: int = 0 + conflict_type: str = "content" + + +@dataclass +class Overlap: + """Represents an overlap between two PRs.""" + pr_a: PullRequest + pr_b: 
PullRequest + overlapping_files: list[str] + line_overlaps: dict[str, list[tuple[int, int]]] + has_merge_conflict: bool = False + conflict_files: list[str] = None + conflict_details: list[ConflictInfo] = None + conflict_type: str = None + + def __post_init__(self): + if self.conflict_files is None: + self.conflict_files = [] + if self.conflict_details is None: + self.conflict_details = [] + + +# ============================================================================= +# CONSTANTS +# ============================================================================= + +IGNORE_FILES = { + "autogpt_platform/frontend/src/app/api/openapi.json", + "poetry.lock", + "pnpm-lock.yaml", + "package-lock.json", + "yarn.lock", +} + + +# ============================================================================= +# ENTRY POINT +# ============================================================================= + +if __name__ == "__main__": + main() diff --git a/.github/workflows/claude-ci-failure-auto-fix.yml b/.github/workflows/claude-ci-failure-auto-fix.yml index ab07c8ae10..dbca6dc3f3 100644 --- a/.github/workflows/claude-ci-failure-auto-fix.yml +++ b/.github/workflows/claude-ci-failure-auto-fix.yml @@ -40,6 +40,48 @@ jobs: git checkout -b "$BRANCH_NAME" echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT + # Backend Python/Poetry setup (so Claude can run linting/tests) + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Set up Python dependency cache + uses: actions/cache@v5 + with: + path: ~/.cache/pypoetry + key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} + + - name: Install Poetry + run: | + cd autogpt_platform/backend + HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry) + curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install Python dependencies + working-directory: autogpt_platform/backend + run: poetry install + + - name: Generate Prisma Client + working-directory: autogpt_platform/backend + run: poetry run prisma generate && poetry run gen-prisma-stub + + # Frontend Node.js/pnpm setup (so Claude can run linting/tests) + - name: Enable corepack + run: corepack enable + + - name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml + + - name: Install JavaScript dependencies + working-directory: autogpt_platform/frontend + run: pnpm install --frozen-lockfile + - name: Get CI failure details id: failure_details uses: actions/github-script@v8 diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index da37df6de7..274c6d2cab 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -77,27 +77,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ 
hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index ee901fe5d4..8b8260af6b 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -93,27 +93,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 966243323c..ff535f8496 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -62,7 +62,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} @@ -93,6 +93,6 @@ jobs: exit 1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/docs-claude-review.yml b/.github/workflows/docs-claude-review.yml index ca2788b387..19d5dd667b 100644 --- a/.github/workflows/docs-claude-review.yml +++ b/.github/workflows/docs-claude-review.yml @@ -7,6 +7,10 @@ on: - "docs/integrations/**" - "autogpt_platform/backend/backend/blocks/**" +concurrency: + group: claude-docs-review-${{ github.event.pull_request.number }} + cancel-in-progress: true + jobs: claude-review: # Only run for PRs from members/collaborators @@ -91,5 +95,35 @@ jobs: 3. Read corresponding documentation files to verify accuracy 4. Provide your feedback as a PR comment + ## IMPORTANT: Comment Marker + Start your PR comment with exactly this HTML comment marker on its own line: + + + This marker is used to identify and replace your comment on subsequent runs. + Be constructive and specific. If everything looks good, say so! If there are issues, explain what's wrong and suggest how to fix it. 
+ + - name: Delete old Claude review comments + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + # Get all comment IDs with our marker, sorted by creation date (oldest first) + COMMENT_IDS=$(gh api \ + repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \ + --jq '[.[] | select(.body | contains(""))] | sort_by(.created_at) | .[].id') + + # Count comments + COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true) + + if [ "$COMMENT_COUNT" -gt 1 ]; then + # Delete all but the last (newest) comment + echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do + if [ -n "$COMMENT_ID" ]; then + echo "Deleting old review comment: $COMMENT_ID" + gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID + fi + done + else + echo "No old review comments to clean up" + fi diff --git a/.github/workflows/pr-overlap-check.yml b/.github/workflows/pr-overlap-check.yml new file mode 100644 index 0000000000..c53f56321b --- /dev/null +++ b/.github/workflows/pr-overlap-check.yml @@ -0,0 +1,39 @@ +name: PR Overlap Detection + +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - dev + - master + +permissions: + contents: read + pull-requests: write + +jobs: + check-overlaps: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history for merge testing + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Configure git + run: | + git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + + - name: Run overlap detection + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # Always succeed - this check informs contributors, it shouldn't block merging + continue-on-error: true + run: | + python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }} diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index fa52ba812a..2711bd2df9 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -104,6 +104,12 @@ TWITTER_CLIENT_SECRET= # Make a new workspace for your OAuth APP -- trust me # https://linear.app/settings/api/applications/new # Callback URL: http://localhost:3000/auth/integrations/oauth_callback +LINEAR_API_KEY= +# Linear project and team IDs for the feature request tracker. +# Find these in your Linear workspace URL: linear.app//project/ +# and in team settings. Used by the chat copilot to file and search feature requests. +LINEAR_FEATURE_REQUEST_PROJECT_ID= +LINEAR_FEATURE_REQUEST_TEAM_ID= LINEAR_CLIENT_ID= LINEAR_CLIENT_SECRET= diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index ace534b730..05a8d4858b 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -66,13 +66,19 @@ ENV POETRY_HOME=/opt/poetry \ DEBIAN_FRONTEND=noninteractive ENV PATH=/opt/poetry/bin:$PATH -# Install Python, FFmpeg, and ImageMagick (required for video processing blocks) +# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use. +# bubblewrap provides OS-level sandbox (whitelist-only FS + no network) +# for the bash_exec MCP tool. # Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc. 
RUN apt-get update && apt-get install -y --no-install-recommends \ python3.13 \ python3-pip \ ffmpeg \ imagemagick \ + jq \ + ripgrep \ + tree \ + bubblewrap \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3* diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index 808692f97f..04bbe8e60d 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -27,12 +27,11 @@ class ChatConfig(BaseSettings): session_ttl: int = Field(default=43200, description="Session TTL in seconds") # Streaming Configuration - max_context_messages: int = Field( - default=50, ge=1, le=200, description="Maximum context messages" - ) - stream_timeout: int = Field(default=300, description="Stream timeout in seconds") - max_retries: int = Field(default=3, description="Maximum number of retries") + max_retries: int = Field( + default=3, + description="Max retries for fallback path (SDK handles retries internally)", + ) max_agent_runs: int = Field(default=30, description="Maximum number of agent runs") max_agent_schedules: int = Field( default=30, description="Maximum number of agent schedules" @@ -93,6 +92,31 @@ class ChatConfig(BaseSettings): description="Name of the prompt in Langfuse to fetch", ) + # Claude Agent SDK Configuration + use_claude_agent_sdk: bool = Field( + default=True, + description="Use Claude Agent SDK for chat completions", + ) + claude_agent_model: str | None = Field( + default=None, + description="Model for the Claude Agent SDK path. If None, derives from " + "the `model` field by stripping the OpenRouter provider prefix.", + ) + claude_agent_max_buffer_size: int = Field( + default=10 * 1024 * 1024, # 10MB (default SDK is 1MB) + description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. " + "Increase if tool outputs exceed the limit.", + ) + claude_agent_max_subtasks: int = Field( + default=10, + description="Max number of sub-agent Tasks the SDK can spawn per session.", + ) + claude_agent_use_resume: bool = Field( + default=True, + description="Use --resume for multi-turn conversations instead of " + "history compression. 
Falls back to compression when unavailable.", + ) + # Extended thinking configuration for Claude models thinking_enabled: bool = Field( default=True, @@ -138,6 +162,17 @@ class ChatConfig(BaseSettings): v = os.getenv("CHAT_INTERNAL_API_KEY") return v + @field_validator("use_claude_agent_sdk", mode="before") + @classmethod + def get_use_claude_agent_sdk(cls, v): + """Get use_claude_agent_sdk from environment if not provided.""" + # Check environment variable - default to True if not set + env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower() + if env_val: + return env_val in ("true", "1", "yes", "on") + # Default to True (SDK enabled by default) + return True if v is None else v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py index 35418f174f..30ac27aece 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/api/features/chat/model.py @@ -334,9 +334,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None: try: session = ChatSession.model_validate_json(raw_session) logger.info( - f"Loading session {session_id} from cache: " - f"message_count={len(session.messages)}, " - f"roles={[m.role for m in session.messages]}" + f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, " + f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles ) return session except Exception as e: @@ -378,11 +377,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None: return None messages = prisma_session.Messages - logger.info( - f"Loading session {session_id} from DB: " - f"has_messages={messages is not None}, " - f"message_count={len(messages) if messages else 0}, " - f"roles={[m.role for m in messages] if messages else []}" + logger.debug( + f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, " + f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles ) return ChatSession.from_db(prisma_session, messages) @@ -433,10 +430,9 @@ async def _save_session_to_db( "function_call": msg.function_call, } ) - logger.info( - f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " - f"roles={[m['role'] for m in messages_data]}, " - f"start_sequence={existing_message_count}" + logger.debug( + f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, " + f"roles={[m['role'] for m in messages_data]}" ) await chat_db.add_chat_messages_batch( session_id=session.session_id, @@ -476,7 +472,7 @@ async def get_chat_session( logger.warning(f"Unexpected cache error for session {session_id}: {e}") # Fall back to database - logger.info(f"Session {session_id} not in cache, checking database") + logger.debug(f"Session {session_id} not in cache, checking database") session = await _get_session_from_db(session_id) if session is None: @@ -493,7 +489,6 @@ async def get_chat_session( # Cache the session from DB try: await _cache_session(session) - logger.info(f"Cached session {session_id} from database") except Exception as e: logger.warning(f"Failed to cache session {session_id}: {e}") @@ -558,6 +553,40 @@ async def upsert_chat_session( return session +async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession: + """Atomically append a message to a session and persist it. 
+ + Acquires the session lock, re-fetches the latest session state, + appends the message, and saves — preventing message loss when + concurrent requests modify the same session. + """ + lock = await _get_session_lock(session_id) + + async with lock: + session = await get_chat_session(session_id) + if session is None: + raise ValueError(f"Session {session_id} not found") + + session.messages.append(message) + existing_message_count = await chat_db.get_chat_session_message_count( + session_id + ) + + try: + await _save_session_to_db(session, existing_message_count) + except Exception as e: + raise DatabaseError( + f"Failed to persist message to session {session_id}" + ) from e + + try: + await _cache_session(session) + except Exception as e: + logger.warning(f"Cache write failed for session {session_id}: {e}") + + return session + + async def create_chat_session(user_id: str) -> ChatSession: """Create a new chat session and persist it. @@ -664,13 +693,19 @@ async def update_session_title(session_id: str, title: str) -> bool: logger.warning(f"Session {session_id} not found for title update") return False - # Invalidate cache so next fetch gets updated title + # Update title in cache if it exists (instead of invalidating). + # This prevents race conditions where cache invalidation causes + # the frontend to see stale DB data while streaming is still in progress. try: - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - await async_redis.delete(redis_key) + cached = await _get_session_from_cache(session_id) + if cached: + cached.title = title + await _cache_session(cached) except Exception as e: - logger.warning(f"Failed to invalidate cache for session {session_id}: {e}") + # Not critical - title will be correct on next full cache refresh + logger.warning( + f"Failed to update title in cache for session {session_id}: {e}" + ) return True except Exception as e: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 0d8b12b0b7..aa565ca891 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,5 +1,6 @@ """Chat API routes for chat session management and streaming via SSE.""" +import asyncio import logging import uuid as uuid_module from collections.abc import AsyncGenerator @@ -11,13 +12,22 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError +from backend.util.feature_flag import Flag, is_feature_enabled from . import service as chat_service from . 
import stream_registry from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig -from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions -from .response_model import StreamFinish, StreamHeartbeat +from .model import ( + ChatMessage, + ChatSession, + append_and_save_message, + create_chat_session, + get_chat_session, + get_user_sessions, +) +from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart +from .sdk import service as sdk_service from .tools.models import ( AgentDetailsResponse, AgentOutputResponse, @@ -41,6 +51,7 @@ from .tools.models import ( SetupRequirementsResponse, UnderstandingUpdatedResponse, ) +from .tracking import track_user_message config = ChatConfig() @@ -232,6 +243,10 @@ async def get_session( active_task, last_message_id = await stream_registry.get_active_task_for_session( session_id, user_id ) + logger.info( + f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, " + f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" + ) if active_task: # Filter out the in-progress assistant message from the session response. # The client will receive the complete assistant response through the SSE @@ -301,10 +316,9 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - session = await _validate_and_get_session(session_id, user_id) logger.info( - f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms", + f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -313,6 +327,25 @@ async def stream_chat_post( }, ) + # Atomically append user message to session BEFORE creating task to avoid + # race condition where GET_SESSION sees task as "running" but message isn't + # saved yet. append_and_save_message re-fetches inside a lock to prevent + # message loss from concurrent requests. 
+ if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, + ) + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + session = await append_and_save_message(session_id, message) + logger.info(f"[STREAM] User message saved for session {session_id}") + # Create a task in the stream registry for reconnection support task_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4()) @@ -328,7 +361,7 @@ async def stream_chat_post( operation_id=operation_id, ) logger.info( - f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms", + f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -349,15 +382,47 @@ async def stream_chat_post( first_chunk_time, ttfc = None, None chunk_count = 0 try: - async for chunk in chat_service.stream_chat_completion( + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + logger.info( + f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": (time_module.perf_counter() - gen_start_time) + * 1000, + } + }, + ) + + # Choose service based on LaunchDarkly flag (falls back to config default) + use_sdk = await is_feature_enabled( + Flag.COPILOT_SDK, + user_id or "anonymous", + default=config.use_claude_agent_sdk, + ) + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else chat_service.stream_chat_completion + ) + logger.info( + f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion", + extra={"json_fields": log_meta}, + ) + # Pass message=None since we already added it to the session above + async for chunk in stream_fn( session_id, - request.message, + None, # Message already in session is_user_message=request.is_user_message, user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch + session=session, # Pass session with message already added context=request.context, - _task_id=task_id, # Pass task_id so service emits start with taskId for reconnection ): + # Skip duplicate StreamStart — we already published one above + if isinstance(chunk, StreamStart): + continue chunk_count += 1 if first_chunk_time is None: first_chunk_time = time_module.perf_counter() @@ -378,7 +443,7 @@ async def stream_chat_post( gen_end_time = time_module.perf_counter() total_time = (gen_end_time - gen_start_time) * 1000 logger.info( - f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; " + f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; " f"task={task_id}, session={session_id}, " f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}", extra={ @@ -405,6 +470,17 @@ async def stream_chat_post( } }, ) + # Publish a StreamError so the frontend can display an error message + try: + await stream_registry.publish_chunk( + task_id, + StreamError( + errorText="An error occurred. 
Please try again.", + code="stream_error", + ), + ) + except Exception: + pass # Best-effort; mark_task_completed will publish StreamFinish await stream_registry.mark_task_completed(task_id, "failed") # Start the AI generation in a background task @@ -507,8 +583,14 @@ async def stream_chat_post( "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)} }, ) + # Surface error to frontend so it doesn't appear stuck + yield StreamError( + errorText="An error occurred. Please try again.", + code="stream_error", + ).to_sse() + yield StreamFinish().to_sse() finally: - # Unsubscribe when client disconnects or stream ends to prevent resource leak + # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: await stream_registry.unsubscribe_from_task( @@ -752,8 +834,6 @@ async def stream_task( ) async def event_generator() -> AsyncGenerator[str, None]: - import asyncio - heartbeat_interval = 15.0 # Send heartbeat every 15 seconds try: while True: diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py new file mode 100644 index 0000000000..7d9d6371e9 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py @@ -0,0 +1,14 @@ +"""Claude Agent SDK integration for CoPilot. + +This module provides the integration layer between the Claude Agent SDK +and the existing CoPilot tool system, enabling drop-in replacement of +the current LLM orchestration with the battle-tested Claude Agent SDK. +""" + +from .service import stream_chat_completion_sdk +from .tool_adapter import create_copilot_mcp_server + +__all__ = [ + "stream_chat_completion_sdk", + "create_copilot_mcp_server", +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py new file mode 100644 index 0000000000..f7151f8319 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py @@ -0,0 +1,203 @@ +"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + +This module provides the adapter layer that converts streaming messages from +the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that +the frontend expects. +""" + +import json +import logging +import uuid + +from claude_agent_sdk import ( + AssistantMessage, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) +from backend.api.features.chat.sdk.tool_adapter import ( + MCP_TOOL_PREFIX, + pop_pending_tool_output, +) + +logger = logging.getLogger(__name__) + + +class SDKResponseAdapter: + """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + + This class maintains state during a streaming session to properly track + text blocks, tool calls, and message lifecycle. 
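    A typical mapping, assuming a single text block followed by one tool call
    (mirrored by test_full_conversation_flow later in this diff):

        SystemMessage(init)       -> StreamStart, StreamStartStep
        AssistantMessage(text)    -> StreamTextStart, StreamTextDelta
        AssistantMessage(tool)    -> StreamTextEnd, StreamToolInputStart, StreamToolInputAvailable
        UserMessage(tool result)  -> StreamToolOutputAvailable, StreamFinishStep
        ResultMessage(success)    -> StreamFinish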
+ """ + + def __init__(self, message_id: str | None = None): + self.message_id = message_id or str(uuid.uuid4()) + self.text_block_id = str(uuid.uuid4()) + self.has_started_text = False + self.has_ended_text = False + self.current_tool_calls: dict[str, dict[str, str]] = {} + self.task_id: str | None = None + self.step_open = False + + def set_task_id(self, task_id: str) -> None: + """Set the task ID for reconnection support.""" + self.task_id = task_id + + def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format.""" + responses: list[StreamBaseResponse] = [] + + if isinstance(sdk_message, SystemMessage): + if sdk_message.subtype == "init": + responses.append( + StreamStart(messageId=self.message_id, taskId=self.task_id) + ) + # Open the first step (matches non-SDK: StreamStart then StreamStartStep) + responses.append(StreamStartStep()) + self.step_open = True + + elif isinstance(sdk_message, AssistantMessage): + # After tool results, the SDK sends a new AssistantMessage for the + # next LLM turn. Open a new step if the previous one was closed. + if not self.step_open: + responses.append(StreamStartStep()) + self.step_open = True + + for block in sdk_message.content: + if isinstance(block, TextBlock): + if block.text: + self._ensure_text_started(responses) + responses.append( + StreamTextDelta(id=self.text_block_id, delta=block.text) + ) + + elif isinstance(block, ToolUseBlock): + self._end_text_if_open(responses) + + # Strip MCP prefix so frontend sees "find_block" + # instead of "mcp__copilot__find_block". + tool_name = block.name.removeprefix(MCP_TOOL_PREFIX) + + responses.append( + StreamToolInputStart(toolCallId=block.id, toolName=tool_name) + ) + responses.append( + StreamToolInputAvailable( + toolCallId=block.id, + toolName=tool_name, + input=block.input, + ) + ) + self.current_tool_calls[block.id] = {"name": tool_name} + + elif isinstance(sdk_message, UserMessage): + # UserMessage carries tool results back from tool execution. + content = sdk_message.content + blocks = content if isinstance(content, list) else [] + for block in blocks: + if isinstance(block, ToolResultBlock) and block.tool_use_id: + tool_info = self.current_tool_calls.get(block.tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + + # Prefer the stashed full output over the SDK's + # (potentially truncated) ToolResultBlock content. + # The SDK truncates large results, writing them to disk, + # which breaks frontend widget parsing. + output = pop_pending_tool_output(tool_name) or ( + _extract_tool_output(block.content) + ) + + responses.append( + StreamToolOutputAvailable( + toolCallId=block.tool_use_id, + toolName=tool_name, + output=output, + success=not (block.is_error or False), + ) + ) + + # Close the current step after tool results — the next + # AssistantMessage will open a new step for the continuation. + if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + elif isinstance(sdk_message, ResultMessage): + self._end_text_if_open(responses) + # Close the step before finishing. 
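When a step and a text block are still open at this point, the branch below yields these terminal sequences (the adapter tests later in this diff assert the same ordering):

# success:                        StreamTextEnd -> StreamFinishStep -> StreamFinish
# error / error_during_execution: StreamTextEnd -> StreamFinishStep -> StreamError -> StreamFinish
# any other subtype:              StreamTextEnd -> StreamFinishStep -> StreamFinish (plus a warning log)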
+ if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + if sdk_message.subtype == "success": + responses.append(StreamFinish()) + elif sdk_message.subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "result", None) or "Unknown error" + responses.append( + StreamError(errorText=str(error_msg), code="sdk_error") + ) + responses.append(StreamFinish()) + else: + logger.warning( + f"Unexpected ResultMessage subtype: {sdk_message.subtype}" + ) + responses.append(StreamFinish()) + + else: + logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}") + + return responses + + def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None: + """Start (or restart) a text block if needed.""" + if not self.has_started_text or self.has_ended_text: + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None: + """End the current text block if one is open.""" + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + +def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str: + """Extract a string output from a ToolResultBlock's content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + if parts: + return "".join(parts) + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) + if content is None: + return "" + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py new file mode 100644 index 0000000000..a4f2502642 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py @@ -0,0 +1,366 @@ +"""Unit tests for the SDK response adapter.""" + +from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .response_adapter import SDKResponseAdapter +from .tool_adapter import MCP_TOOL_PREFIX + + +def _adapter() -> SDKResponseAdapter: + a = SDKResponseAdapter(message_id="msg-1") + a.set_task_id("task-1") + return a + + +# -- SystemMessage ----------------------------------------------------------- + + +def test_system_init_emits_start_and_step(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="init", data={})) + assert len(results) == 2 + assert isinstance(results[0], StreamStart) + assert results[0].messageId == "msg-1" + assert results[0].taskId == "task-1" + assert isinstance(results[1], StreamStartStep) + + +def test_system_non_init_emits_nothing(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="other", data={})) + assert results == [] + + +# -- 
AssistantMessage with TextBlock ----------------------------------------- + + +def test_text_block_emits_step_start_and_delta(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="hello")], model="test") + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "hello" + + +def test_empty_text_block_emits_only_step(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="")], model="test") + results = adapter.convert_message(msg) + # Empty text skipped, but step still opens + assert len(results) == 1 + assert isinstance(results[0], StreamStartStep) + + +def test_multiple_text_deltas_reuse_block_id(): + adapter = _adapter() + msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test") + msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test") + r1 = adapter.convert_message(msg1) + r2 = adapter.convert_message(msg2) + # First gets step+start+delta, second only delta (block & step already started) + assert len(r1) == 3 + assert isinstance(r1[0], StreamStartStep) + assert isinstance(r1[1], StreamTextStart) + assert len(r2) == 1 + assert isinstance(r2[0], StreamTextDelta) + assert r1[1].id == r2[0].id # same block ID + + +# -- AssistantMessage with ToolUseBlock -------------------------------------- + + +def test_tool_use_emits_input_start_and_available(): + """Tool names arrive with MCP prefix and should be stripped for the frontend.""" + adapter = _adapter() + msg = AssistantMessage( + content=[ + ToolUseBlock( + id="tool-1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"q": "x"}, + ) + ], + model="test", + ) + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamToolInputStart) + assert results[1].toolCallId == "tool-1" + assert results[1].toolName == "find_agent" # prefix stripped + assert isinstance(results[2], StreamToolInputAvailable) + assert results[2].toolName == "find_agent" # prefix stripped + assert results[2].input == {"q": "x"} + + +def test_text_then_tool_ends_text_block(): + adapter = _adapter() + text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + adapter.convert_message(text_msg) # opens step + text + results = adapter.convert_message(tool_msg) + # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamToolInputStart) + + +# -- UserMessage with ToolResultBlock ---------------------------------------- + + +def test_tool_result_emits_output_and_finish_step(): + adapter = _adapter() + # First register the tool call (opens step) — SDK sends prefixed name + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})], + model="test", + ) + adapter.convert_message(tool_msg) + + # Now send tool result + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")] + ) + results = adapter.convert_message(result_msg) + assert len(results) == 2 + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].toolCallId == "t1" + assert results[0].toolName == 
"find_agent" # prefix stripped + assert results[0].output == "found 3 agents" + assert results[0].success is True + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_error(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={}) + ], + model="test", + ) + ) + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].success is False + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_list_content(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + result_msg = UserMessage( + content=[ + ToolResultBlock( + tool_use_id="t1", + content=[ + {"type": "text", "text": "line1"}, + {"type": "text", "text": "line2"}, + ], + ) + ] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].output == "line1line2" + assert isinstance(results[1], StreamFinishStep) + + +def test_string_user_message_ignored(): + """A plain string UserMessage (not tool results) produces no output.""" + adapter = _adapter() + results = adapter.convert_message(UserMessage(content="hello")) + assert results == [] + + +# -- ResultMessage ----------------------------------------------------------- + + +def test_result_success_emits_finish_step_and_finish(): + adapter = _adapter() + # Start some text first (opens step) + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="done")], model="test") + ) + msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="s1", + ) + results = adapter.convert_message(msg) + # TextEnd + FinishStep + StreamFinish + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamFinishStep) + assert isinstance(results[2], StreamFinish) + + +def test_result_error_emits_error_and_finish(): + adapter = _adapter() + msg = ResultMessage( + subtype="error", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=0, + session_id="s1", + result="API rate limited", + ) + results = adapter.convert_message(msg) + # No step was open, so no FinishStep — just Error + Finish + assert len(results) == 2 + assert isinstance(results[0], StreamError) + assert "API rate limited" in results[0].errorText + assert isinstance(results[1], StreamFinish) + + +# -- Text after tools (new block ID) ---------------------------------------- + + +def test_text_after_tool_gets_new_block_id(): + adapter = _adapter() + # Text -> Tool -> ToolResult -> Text should get a new text block ID and step + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="before")], model="test") + ) + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + # Send tool result (closes step) + adapter.convert_message( + UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")]) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="after")], model="test") + ) + # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta + assert 
len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "after" + + +# -- Full conversation flow -------------------------------------------------- + + +def test_full_conversation_flow(): + """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish.""" + adapter = _adapter() + all_responses: list[StreamBaseResponse] = [] + + # 1. Init + all_responses.extend( + adapter.convert_message(SystemMessage(subtype="init", data={})) + ) + # 2. Assistant text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="Let me search")], model="test") + ) + ) + # 3. Tool use + all_responses.extend( + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock( + id="t1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"query": "email"}, + ) + ], + model="test", + ) + ) + ) + # 4. Tool result + all_responses.extend( + adapter.convert_message( + UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")] + ) + ) + ) + # 5. More text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="I found 2")], model="test") + ) + ) + # 6. Result + all_responses.extend( + adapter.convert_message( + ResultMessage( + subtype="success", + duration_ms=500, + duration_api_ms=400, + is_error=False, + num_turns=2, + session_id="s1", + ) + ) + ) + + types = [type(r).__name__ for r in all_responses] + assert types == [ + "StreamStart", + "StreamStartStep", # step 1: text + tool call + "StreamTextStart", + "StreamTextDelta", # "Let me search" + "StreamTextEnd", # closed before tool + "StreamToolInputStart", + "StreamToolInputAvailable", + "StreamToolOutputAvailable", # tool result + "StreamFinishStep", # step 1 closed after tool result + "StreamStartStep", # step 2: continuation text + "StreamTextStart", # new block after tool + "StreamTextDelta", # "I found 2" + "StreamTextEnd", # closed by result + "StreamFinishStep", # step 2 closed + "StreamFinish", + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py new file mode 100644 index 0000000000..89853402a3 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -0,0 +1,305 @@ +"""Security hooks for Claude Agent SDK integration. + +This module provides security hooks that validate tool calls before execution, +ensuring multi-user isolation and preventing unauthorized operations. +""" + +import json +import logging +import os +import re +from collections.abc import Callable +from typing import Any, cast + +from backend.api.features.chat.sdk.tool_adapter import ( + BLOCKED_TOOLS, + DANGEROUS_PATTERNS, + MCP_TOOL_PREFIX, + WORKSPACE_SCOPED_TOOLS, +) + +logger = logging.getLogger(__name__) + + +def _deny(reason: str) -> dict[str, Any]: + """Return a hook denial response.""" + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": reason, + } + } + + +def _validate_workspace_path( + tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None +) -> dict[str, Any]: + """Validate that a workspace-scoped tool only accesses allowed paths. 
+ + Allowed directories: + - The SDK working directory (``/tmp/copilot-/``) + - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``) + """ + path = tool_input.get("file_path") or tool_input.get("path") or "" + if not path: + # Glob/Grep without a path default to cwd which is already sandboxed + return {} + + # Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM + # naturally uses relative paths like "test.txt" instead of absolute ones). + # Tilde paths (~/) are home-dir references, not relative — expand first. + if path.startswith("~"): + resolved = os.path.realpath(os.path.expanduser(path)) + elif not os.path.isabs(path) and sdk_cwd: + resolved = os.path.realpath(os.path.join(sdk_cwd, path)) + else: + resolved = os.path.realpath(path) + + # Allow access within the SDK working directory + if sdk_cwd: + norm_cwd = os.path.realpath(sdk_cwd) + if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd: + return {} + + # Allow access to ~/.claude/projects/*/tool-results/ (big tool results) + claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects")) + tool_results_seg = os.sep + "tool-results" + os.sep + if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved: + return {} + + logger.warning( + f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})" + ) + workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else "" + return _deny( + f"[SECURITY] Tool '{tool_name}' can only access files within the workspace " + f"directory.{workspace_hint} " + "This is enforced by the platform and cannot be bypassed." + ) + + +def _validate_tool_access( + tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None +) -> dict[str, Any]: + """Validate that a tool call is allowed. + + Returns: + Empty dict to allow, or dict with hookSpecificOutput to deny + """ + # Block forbidden tools + if tool_name in BLOCKED_TOOLS: + logger.warning(f"Blocked tool access attempt: {tool_name}") + return _deny( + f"[SECURITY] Tool '{tool_name}' is blocked for security. " + "This is enforced by the platform and cannot be bypassed. " + "Use the CoPilot-specific MCP tools instead." + ) + + # Workspace-scoped tools: allowed only within the SDK workspace directory + if tool_name in WORKSPACE_SCOPED_TOOLS: + return _validate_workspace_path(tool_name, tool_input, sdk_cwd) + + # Check for dangerous patterns in tool input + # Use json.dumps for predictable format (str() produces Python repr) + input_str = json.dumps(tool_input) if tool_input else "" + + for pattern in DANGEROUS_PATTERNS: + if re.search(pattern, input_str, re.IGNORECASE): + logger.warning( + f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}" + ) + return _deny( + "[SECURITY] Input contains a blocked pattern. " + "This is enforced by the platform and cannot be bypassed." + ) + + return {} + + +def _validate_user_isolation( + tool_name: str, tool_input: dict[str, Any], user_id: str | None +) -> dict[str, Any]: + """Validate that tool calls respect user isolation.""" + # For workspace file tools, ensure path doesn't escape + if "workspace" in tool_name.lower(): + path = tool_input.get("path", "") or tool_input.get("file_path", "") + if path: + # Check for path traversal + if ".." 
in path or path.startswith("/"): + logger.warning( + f"Blocked path traversal attempt: {path} by user {user_id}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Path traversal not allowed", + } + } + + return {} + + +def create_security_hooks( + user_id: str | None, + sdk_cwd: str | None = None, + max_subtasks: int = 3, + on_stop: Callable[[str, str], None] | None = None, +) -> dict[str, Any]: + """Create the security hooks configuration for Claude Agent SDK. + + Includes security validation and observability hooks: + - PreToolUse: Security validation before tool execution + - PostToolUse: Log successful tool executions + - PostToolUseFailure: Log and handle failed tool executions + - PreCompact: Log context compaction events (SDK handles compaction automatically) + - Stop: Capture transcript path for stateless resume (when *on_stop* is provided) + + Args: + user_id: Current user ID for isolation validation + sdk_cwd: SDK working directory for workspace-scoped tool validation + max_subtasks: Maximum Task (sub-agent) spawns allowed per session + on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when + the SDK finishes processing — used to read the JSONL transcript + before the CLI process exits. + + Returns: + Hooks configuration dict for ClaudeAgentOptions + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + # Per-session counter for Task sub-agent spawns + task_spawn_count = 0 + + async def pre_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Combined pre-tool-use validation hook.""" + nonlocal task_spawn_count + _ = context # unused but required by signature + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Rate-limit Task (sub-agent) spawns per session + if tool_name == "Task": + task_spawn_count += 1 + if task_spawn_count > max_subtasks: + logger.warning( + f"[SDK] Task limit reached ({max_subtasks}), user={user_id}" + ) + return cast( + SyncHookJSONOutput, + _deny( + f"Maximum {max_subtasks} sub-tasks per session. " + "Please continue in the main conversation." + ), + ) + + # Strip MCP prefix for consistent validation + is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX) + clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX) + + # Only block non-CoPilot tools; our MCP-registered tools + # (including Read for oversized results) are already sandboxed. 
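For reference, a denial from _validate_tool_access or _validate_user_isolation is the PreToolUse payload built by _deny above. A small sketch of that shape, together with the predicate the unit tests use to read it (the reason string is only an example):

denial = {
    "hookSpecificOutput": {
        "hookEventName": "PreToolUse",
        "permissionDecision": "deny",
        "permissionDecisionReason": "[SECURITY] Tool 'bash' is blocked for security.",
    }
}


def is_denied(result: dict) -> bool:
    # An empty dict means "allow"; anything else carries the deny decision.
    return result.get("hookSpecificOutput", {}).get("permissionDecision") == "deny"


assert is_denied(denial)
assert not is_denied({})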
+ if not is_copilot_tool: + result = _validate_tool_access(clean_name, tool_input, sdk_cwd) + if result: + return cast(SyncHookJSONOutput, result) + + # Validate user isolation + result = _validate_user_isolation(clean_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log successful tool executions for observability.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_failure_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log failed tool executions for debugging.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + error = input_data.get("error", "Unknown error") + logger.warning( + f"[SDK] Tool failed: {tool_name}, error={error}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + return cast(SyncHookJSONOutput, {}) + + async def pre_compact_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log when SDK triggers context compaction. + + The SDK automatically compacts conversation history when it grows too large. + This hook provides visibility into when compaction happens. + """ + _ = context, tool_use_id + trigger = input_data.get("trigger", "auto") + logger.info( + f"[SDK] Context compaction triggered: {trigger}, user={user_id}" + ) + return cast(SyncHookJSONOutput, {}) + + # --- Stop hook: capture transcript path for stateless resume --- + async def stop_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Capture transcript path when SDK finishes processing. + + The Stop hook fires while the CLI process is still alive, giving us + a reliable window to read the JSONL transcript before SIGTERM. + """ + _ = context, tool_use_id + transcript_path = cast(str, input_data.get("transcript_path", "")) + sdk_session_id = cast(str, input_data.get("session_id", "")) + + if transcript_path and on_stop: + logger.info( + f"[SDK] Stop hook: transcript_path={transcript_path}, " + f"sdk_session_id={sdk_session_id[:12]}..." 
+ ) + on_stop(transcript_path, sdk_session_id) + + return cast(SyncHookJSONOutput, {}) + + hooks: dict[str, Any] = { + "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])], + "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])], + "PostToolUseFailure": [ + HookMatcher(matcher="*", hooks=[post_tool_failure_hook]) + ], + "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])], + } + + if on_stop is not None: + hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])] + + return hooks + except ImportError: + # Fallback for when SDK isn't available - return empty hooks + logger.warning("claude-agent-sdk not available, security hooks disabled") + return {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py new file mode 100644 index 0000000000..2d09afdab7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py @@ -0,0 +1,165 @@ +"""Unit tests for SDK security hooks.""" + +import os + +from .security_hooks import _validate_tool_access, _validate_user_isolation + +SDK_CWD = "/tmp/copilot-abc123" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +# -- Blocked tools ----------------------------------------------------------- + + +def test_blocked_tools_denied(): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"{tool} should be blocked" + + +def test_unknown_tool_allowed(): + result = _validate_tool_access("SomeCustomTool", {}) + assert result == {} + + +# -- Workspace-scoped tools -------------------------------------------------- + + +def test_read_within_workspace_allowed(): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_write_within_workspace_allowed(): + result = _validate_tool_access( + "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_edit_within_workspace_allowed(): + result = _validate_tool_access( + "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_glob_within_workspace_allowed(): + result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_grep_within_workspace_allowed(): + result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_outside_workspace_denied(): + result = _validate_tool_access( + "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_write_outside_workspace_denied(): + result = _validate_tool_access( + "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_traversal_attack_denied(): + result = _validate_tool_access( + "Read", + {"file_path": f"{SDK_CWD}/../../etc/passwd"}, + sdk_cwd=SDK_CWD, + ) + assert _is_denied(result) + + +def test_no_path_allowed(): + """Glob/Grep without a path argument defaults to cwd — should pass.""" + result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_no_cwd_denies_absolute(): + """If no sdk_cwd is set, absolute paths are denied.""" + result = _validate_tool_access("Read", {"file_path": 
"/tmp/anything"}) + assert _is_denied(result) + + +# -- Tool-results directory -------------------------------------------------- + + +def test_read_tool_results_allowed(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_claude_projects_without_tool_results_denied(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Built-in Bash is blocked (use bash_exec MCP tool instead) --------------- + + +def test_bash_builtin_always_blocked(): + """SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Dangerous patterns ------------------------------------------------------ + + +def test_dangerous_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"}) + assert _is_denied(result) + + +def test_subprocess_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"}) + assert _is_denied(result) + + +# -- User isolation ---------------------------------------------------------- + + +def test_workspace_path_traversal_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_absolute_path_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "/etc/passwd"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_normal_path_allowed(): + result = _validate_user_isolation( + "workspace_read", {"path": "src/main.py"}, user_id="user-1" + ) + assert result == {} + + +def test_non_workspace_tool_passes_isolation(): + result = _validate_user_isolation( + "find_agent", {"query": "email"}, user_id="user-1" + ) + assert result == {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py new file mode 100644 index 0000000000..65c4cebb06 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -0,0 +1,752 @@ +"""Claude Agent SDK service layer for CoPilot chat completions.""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Any + +from backend.util.exceptions import NotFoundError + +from .. 
import stream_registry +from ..config import ChatConfig +from ..model import ( + ChatMessage, + ChatSession, + get_chat_session, + update_session_title, + upsert_chat_session, +) +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamToolInputAvailable, + StreamToolOutputAvailable, +) +from ..service import ( + _build_system_prompt, + _execute_long_running_tool_with_streaming, + _generate_session_title, +) +from ..tools.models import OperationPendingResponse, OperationStartedResponse +from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path +from ..tracking import track_user_message +from .response_adapter import SDKResponseAdapter +from .security_hooks import create_security_hooks +from .tool_adapter import ( + COPILOT_TOOL_NAMES, + SDK_DISALLOWED_TOOLS, + LongRunningCallback, + create_copilot_mcp_server, + set_execution_context, +) +from .transcript import ( + download_transcript, + read_transcript_file, + upload_transcript, + validate_transcript, + write_transcript_to_tempfile, +) + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Set to hold background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[Any]] = set() + + +@dataclass +class CapturedTranscript: + """Info captured by the SDK Stop hook for stateless --resume.""" + + path: str = "" + sdk_session_id: str = "" + + @property + def available(self) -> bool: + return bool(self.path) + + +_SDK_CWD_PREFIX = WORKSPACE_PREFIX + +# Appended to the system prompt to inform the agent about available tools. +# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). +_SDK_TOOL_SUPPLEMENT = """ + +## Tool notes + +- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool + for shell commands — it runs in a network-isolated sandbox. +- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the + same working directory. Files created by one are readable by the other. + These files are **ephemeral** — they exist only for the current session. +- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file` + for files that should persist across sessions (stored in cloud storage). +- Long-running tools (create_agent, edit_agent, etc.) are handled + asynchronously. You will receive an immediate response; the actual result + is delivered to the user via a background stream. +""" + + +def _build_long_running_callback(user_id: str | None) -> LongRunningCallback: + """Build a callback that delegates long-running tools to the non-SDK infrastructure. + + Long-running tools (create_agent, edit_agent, etc.) are delegated to the + existing background infrastructure: stream_registry (Redis Streams), + database persistence, and SSE reconnection. This means results survive + page refreshes / pod restarts, and the frontend shows the proper loading + widget with progress updates. + + The returned callback matches the ``LongRunningCallback`` signature: + ``(tool_name, args, session) -> MCP response dict``. 
+ """ + + async def _callback( + tool_name: str, args: dict[str, Any], session: ChatSession + ) -> dict[str, Any]: + operation_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}" + session_id = session.session_id + + # --- Build user-friendly messages (matches non-SDK service) --- + if tool_name == "create_agent": + desc = args.get("description", "") + desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc + pending_msg = ( + f"Creating your agent: {desc_preview}" + if desc_preview + else "Creating agent... This may take a few minutes." + ) + started_msg = ( + "Agent creation started. You can close this tab - " + "check your library in a few minutes." + ) + elif tool_name == "edit_agent": + changes = args.get("changes", "") + changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes + pending_msg = ( + f"Editing agent: {changes_preview}" + if changes_preview + else "Editing agent... This may take a few minutes." + ) + started_msg = ( + "Agent edit started. You can close this tab - " + "check your library in a few minutes." + ) + else: + pending_msg = f"Running {tool_name}... This may take a few minutes." + started_msg = ( + f"{tool_name} started. You can close this tab - " + "check back in a few minutes." + ) + + # --- Register task in Redis for SSE reconnection --- + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # --- Save OperationPendingResponse to chat history --- + pending_message = ChatMessage( + role="tool", + content=OperationPendingResponse( + message=pending_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + tool_call_id=tool_call_id, + ) + session.messages.append(pending_message) + await upsert_chat_session(session) + + # --- Spawn background task (reuses non-SDK infrastructure) --- + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( + tool_name=tool_name, + parameters=args, + tool_call_id=tool_call_id, + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + user_id=user_id, + ) + ) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + logger.info( + f"[SDK] Long-running tool {tool_name} delegated to background " + f"(operation_id={operation_id}, task_id={task_id})" + ) + + # --- Return OperationStartedResponse as MCP tool result --- + # This flows through SDK → response adapter → frontend, triggering + # the loading widget with SSE reconnection support. + started_json = OperationStartedResponse( + message=started_msg, + operation_id=operation_id, + tool_name=tool_name, + task_id=task_id, + ).model_dump_json() + + return { + "content": [{"type": "text", "text": started_json}], + "isError": False, + } + + return _callback + + +def _resolve_sdk_model() -> str | None: + """Resolve the model name for the Claude Agent SDK CLI. + + Uses ``config.claude_agent_model`` if set, otherwise derives from + ``config.model`` by stripping the OpenRouter provider prefix (e.g., + ``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``). + """ + if config.claude_agent_model: + return config.claude_agent_model + model = config.model + if "/" in model: + return model.split("/", 1)[1] + return model + + +def _build_sdk_env() -> dict[str, str]: + """Build env vars for the SDK CLI process. 
+ + Routes API calls through OpenRouter (or a custom base_url) using + the same ``config.api_key`` / ``config.base_url`` as the non-SDK path. + This gives per-call token and cost tracking on the OpenRouter dashboard. + + Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth + token are both present — otherwise returns an empty dict so the SDK + falls back to its default credentials. + """ + env: dict[str, str] = {} + if config.api_key and config.base_url: + # Strip /v1 suffix — SDK expects the base URL without a version path + base = config.base_url.rstrip("/") + if base.endswith("/v1"): + base = base[:-3] + if not base or not base.startswith("http"): + # Invalid base_url — don't override SDK defaults + return env + env["ANTHROPIC_BASE_URL"] = base + env["ANTHROPIC_AUTH_TOKEN"] = config.api_key + # Must be explicitly empty so the CLI uses AUTH_TOKEN instead + env["ANTHROPIC_API_KEY"] = "" + return env + + +def _make_sdk_cwd(session_id: str) -> str: + """Create a safe, session-specific working directory path. + + Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path` + (single source of truth for path sanitization) and adds a defence-in-depth + assertion. + """ + cwd = make_session_path(session_id) + # Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer + cwd = os.path.normpath(cwd) + if not cwd.startswith(_SDK_CWD_PREFIX): + raise ValueError(f"SDK cwd escaped prefix: {cwd}") + return cwd + + +def _cleanup_sdk_tool_results(cwd: str) -> None: + """Remove SDK tool-result files for a specific session working directory. + + The SDK creates tool-result files under ~/.claude/projects//tool-results/. + We clean only the specific cwd's results to avoid race conditions between + concurrent sessions. + + Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id. + """ + import shutil + + # Validate cwd is under the expected prefix + normalized = os.path.normpath(cwd) + if not normalized.startswith(_SDK_CWD_PREFIX): + logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}") + return + + # SDK encodes the cwd path by replacing '/' with '-' + encoded_cwd = normalized.replace("/", "-") + + # Construct the project directory path (known-safe home expansion) + claude_projects = os.path.expanduser("~/.claude/projects") + project_dir = os.path.join(claude_projects, encoded_cwd) + + # Security check 3: Validate project_dir is under ~/.claude/projects + project_dir = os.path.normpath(project_dir) + if not project_dir.startswith(claude_projects): + logger.warning( + f"[SDK] Rejecting cleanup for escaped project path: {project_dir}" + ) + return + + results_dir = os.path.join(project_dir, "tool-results") + if os.path.isdir(results_dir): + for filename in os.listdir(results_dir): + file_path = os.path.join(results_dir, filename) + try: + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + pass + + # Also clean up the temp cwd directory itself + try: + shutil.rmtree(normalized, ignore_errors=True) + except OSError: + pass + + +async def _compress_conversation_history( + session: ChatSession, +) -> list[ChatMessage]: + """Compress prior conversation messages if they exceed the token threshold. + + Uses the shared compress_context() from prompt.py which supports: + - LLM summarization of old messages (keeps recent ones intact) + - Progressive content truncation as fallback + - Middle-out deletion as last resort + + Returns the compressed prior messages (everything except the current message). 
+ """ + prior = session.messages[:-1] + if len(prior) < 2: + return prior + + from backend.util.prompt import compress_context + + # Convert ChatMessages to dicts for compress_context + messages_dict = [] + for msg in prior: + msg_dict: dict[str, Any] = {"role": msg.role} + if msg.content: + msg_dict["content"] = msg.content + if msg.tool_calls: + msg_dict["tool_calls"] = msg.tool_calls + if msg.tool_call_id: + msg_dict["tool_call_id"] = msg.tool_call_id + messages_dict.append(msg_dict) + + try: + import openai + + async with openai.AsyncOpenAI( + api_key=config.api_key, base_url=config.base_url, timeout=30.0 + ) as client: + result = await compress_context( + messages=messages_dict, + model=config.model, + client=client, + ) + except Exception as e: + logger.warning(f"[SDK] Context compression with LLM failed: {e}") + # Fall back to truncation-only (no LLM summarization) + result = await compress_context( + messages=messages_dict, + model=config.model, + client=None, + ) + + if result.was_compacted: + logger.info( + f"[SDK] Context compacted: {result.original_token_count} -> " + f"{result.token_count} tokens " + f"({result.messages_summarized} summarized, " + f"{result.messages_dropped} dropped)" + ) + # Convert compressed dicts back to ChatMessages + return [ + ChatMessage( + role=m["role"], + content=m.get("content"), + tool_calls=m.get("tool_calls"), + tool_call_id=m.get("tool_call_id"), + ) + for m in result.messages + ] + + return prior + + +def _format_conversation_context(messages: list[ChatMessage]) -> str | None: + """Format conversation messages into a context prefix for the user message. + + Returns a string like: + + User: hello + You responded: Hi! How can I help? + + + Returns None if there are no messages to format. + """ + if not messages: + return None + + lines: list[str] = [] + for msg in messages: + if not msg.content: + continue + if msg.role == "user": + lines.append(f"User: {msg.content}") + elif msg.role == "assistant": + lines.append(f"You responded: {msg.content}") + # Skip tool messages — they're internal details + + if not lines: + return None + + return "\n" + "\n".join(lines) + "\n" + + +async def stream_chat_completion_sdk( + session_id: str, + message: str | None = None, + tool_call_response: str | None = None, # noqa: ARG001 + is_user_message: bool = True, + user_id: str | None = None, + retry_count: int = 0, # noqa: ARG001 + session: ChatSession | None = None, + context: dict[str, str] | None = None, # noqa: ARG001 +) -> AsyncGenerator[StreamBaseResponse, None]: + """Stream chat completion using Claude Agent SDK. + + Drop-in replacement for stream_chat_completion with improved reliability. + """ + + if session is None: + session = await get_chat_session(session_id, user_id) + + if not session: + raise NotFoundError( + f"Session {session_id} not found. Please create a new session first." 
+ ) + + if message: + session.messages.append( + ChatMessage( + role="user" if is_user_message else "assistant", content=message + ) + ) + if is_user_message: + track_user_message( + user_id=user_id, session_id=session_id, message_length=len(message) + ) + + session = await upsert_chat_session(session) + + # Generate title for new sessions (first user message) + if is_user_message and not session.title: + user_messages = [m for m in session.messages if m.role == "user"] + if len(user_messages) == 1: + first_message = user_messages[0].content or message or "" + if first_message: + task = asyncio.create_task( + _update_title_async(session_id, first_message, user_id) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + # Build system prompt (reuses non-SDK path with Langfuse support) + has_history = len(session.messages) > 1 + system_prompt, _ = await _build_system_prompt( + user_id, has_conversation_history=has_history + ) + system_prompt += _SDK_TOOL_SUPPLEMENT + message_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + + yield StreamStart(messageId=message_id, taskId=task_id) + + stream_completed = False + # Initialise sdk_cwd before the try so the finally can reference it + # even if _make_sdk_cwd raises (in that case it stays as ""). + sdk_cwd = "" + use_resume = False + + try: + # Use a session-specific temp dir to avoid cleanup race conditions + # between concurrent sessions. + sdk_cwd = _make_sdk_cwd(session_id) + os.makedirs(sdk_cwd, exist_ok=True) + + set_execution_context( + user_id, + session, + long_running_callback=_build_long_running_callback(user_id), + ) + try: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + # Fail fast when no API credentials are available at all + sdk_env = _build_sdk_env() + if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"): + raise RuntimeError( + "No API key configured. Set OPEN_ROUTER_API_KEY " + "(or CHAT_API_KEY) for OpenRouter routing, " + "or ANTHROPIC_API_KEY for direct Anthropic access." 
+ ) + + mcp_server = create_copilot_mcp_server() + + sdk_model = _resolve_sdk_model() + + # --- Transcript capture via Stop hook --- + captured_transcript = CapturedTranscript() + + def _on_stop(transcript_path: str, sdk_session_id: str) -> None: + captured_transcript.path = transcript_path + captured_transcript.sdk_session_id = sdk_session_id + + security_hooks = create_security_hooks( + user_id, + sdk_cwd=sdk_cwd, + max_subtasks=config.claude_agent_max_subtasks, + on_stop=_on_stop if config.claude_agent_use_resume else None, + ) + + # --- Resume strategy: download transcript from bucket --- + resume_file: str | None = None + use_resume = False + + if config.claude_agent_use_resume and user_id and len(session.messages) > 1: + transcript_content = await download_transcript(user_id, session_id) + if transcript_content and validate_transcript(transcript_content): + resume_file = write_transcript_to_tempfile( + transcript_content, session_id, sdk_cwd + ) + if resume_file: + use_resume = True + logger.info( + f"[SDK] Using --resume with transcript " + f"({len(transcript_content)} bytes)" + ) + + sdk_options_kwargs: dict[str, Any] = { + "system_prompt": system_prompt, + "mcp_servers": {"copilot": mcp_server}, + "allowed_tools": COPILOT_TOOL_NAMES, + "disallowed_tools": SDK_DISALLOWED_TOOLS, + "hooks": security_hooks, + "cwd": sdk_cwd, + "max_buffer_size": config.claude_agent_max_buffer_size, + } + if sdk_env: + sdk_options_kwargs["model"] = sdk_model + sdk_options_kwargs["env"] = sdk_env + if use_resume and resume_file: + sdk_options_kwargs["resume"] = resume_file + + options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] + + adapter = SDKResponseAdapter(message_id=message_id) + adapter.set_task_id(task_id) + + async with ClaudeSDKClient(options=options) as client: + current_message = message or "" + if not current_message and session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" + + if not current_message.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return + + # Build query: with --resume the CLI already has full + # context, so we only send the new message. Without + # resume, compress history into a context prefix. 
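To summarise the branch below: when a validated transcript was loaded for --resume, only the new message is sent because the CLI already holds the history; otherwise the prior turns are compressed and prepended as a plain-text prefix. Assuming the _format_conversation_context helper shown earlier, the fallback query roughly takes this shape (example values only):

# <history prefix produced by _format_conversation_context>
# User: hello
# You responded: Hi! How can I help?
#
# Now, the user says:
# please build me an email digest agent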
+ query_message = current_message + if not use_resume and len(session.messages) > 1: + logger.warning( + f"[SDK] Using compression fallback for session " + f"{session_id} ({len(session.messages)} messages) — " + f"no transcript available for --resume" + ) + compressed = await _compress_conversation_history(session) + history_context = _format_conversation_context(compressed) + if history_context: + query_message = ( + f"{history_context}\n\n" + f"Now, the user says:\n{current_message}" + ) + + logger.info( + f"[SDK] Sending query ({len(session.messages)} msgs in session)" + ) + logger.debug(f"[SDK] Query preview: {current_message[:80]!r}") + await client.query(query_message, session_id=session_id) + + assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] + has_appended_assistant = False + has_tool_results = False + + async for sdk_msg in client.receive_messages(): + logger.debug( + f"[SDK] Received: {type(sdk_msg).__name__} " + f"{getattr(sdk_msg, 'subtype', '')}" + ) + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + + yield response + + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, start a new assistant + # message for the post-tool text. + if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + accumulated_tool_calls = [] + has_appended_assistant = False + has_tool_results = False + session.messages.append(assistant_response) + has_appended_assistant = True + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps(response.input or {}), + }, + } + ) + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else str(response.output) + ), + tool_call_id=response.toolCallId, + ) + ) + has_tool_results = True + + elif isinstance(response, StreamFinish): + stream_completed = True + + if stream_completed: + break + + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: + session.messages.append(assistant_response) + + # --- Capture transcript while CLI is still alive --- + # Must happen INSIDE async with: close() sends SIGTERM + # which kills the CLI before it can flush the JSONL. + if ( + config.claude_agent_use_resume + and user_id + and captured_transcript.available + ): + # Give CLI time to flush JSONL writes before we read + await asyncio.sleep(0.5) + raw_transcript = read_transcript_file(captured_transcript.path) + if raw_transcript: + task = asyncio.create_task( + _upload_transcript_bg(user_id, session_id, raw_transcript) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + else: + logger.debug("[SDK] Stop hook fired but transcript not usable") + + except ImportError: + raise RuntimeError( + "claude-agent-sdk is not installed. 
" + "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) " + "to use the OpenAI-compatible fallback." + ) + + await upsert_chat_session(session) + logger.debug( + f"[SDK] Session {session_id} saved with {len(session.messages)} messages" + ) + if not stream_completed: + yield StreamFinish() + + except Exception as e: + logger.error(f"[SDK] Error: {e}", exc_info=True) + try: + await upsert_chat_session(session) + except Exception as save_err: + logger.error(f"[SDK] Failed to save session on error: {save_err}") + yield StreamError( + errorText="An error occurred. Please try again.", + code="sdk_error", + ) + yield StreamFinish() + finally: + if sdk_cwd: + _cleanup_sdk_tool_results(sdk_cwd) + + +async def _upload_transcript_bg( + user_id: str, session_id: str, raw_content: str +) -> None: + """Background task to strip progress entries and upload transcript.""" + try: + await upload_transcript(user_id, session_id, raw_content) + except Exception as e: + logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}") + + +async def _update_title_async( + session_id: str, message: str, user_id: str | None = None +) -> None: + """Background task to update session title.""" + try: + title = await _generate_session_title( + message, user_id=user_id, session_id=session_id + ) + if title: + await update_session_title(session_id, title) + logger.debug(f"[SDK] Generated title for {session_id}: {title}") + except Exception as e: + logger.warning(f"[SDK] Failed to update session title: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py new file mode 100644 index 0000000000..2d259730bf --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -0,0 +1,363 @@ +"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools. + +This module provides the adapter layer that converts existing BaseTool implementations +into in-process MCP tools that can be used with the Claude Agent SDK. + +Long-running tools (``is_long_running=True``) are delegated to the non-SDK +background infrastructure (stream_registry, Redis persistence, SSE reconnection) +via a callback provided by the service layer. This avoids wasteful SDK polling +and makes results survive page refreshes. +""" + +import itertools +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools import TOOL_REGISTRY +from backend.api.features.chat.tools.base import BaseTool + +logger = logging.getLogger(__name__) + +# Allowed base directory for the Read tool (SDK saves oversized tool results here). +# Restricted to ~/.claude/projects/ and further validated to require "tool-results" +# in the path — prevents reading settings, credentials, or other sensitive files. 
+_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/") + +# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}" +MCP_SERVER_NAME = "copilot" +MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__" + +# Context variables to pass user/session info to tool execution +_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None) +_current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) +# Stash for MCP tool outputs before the SDK potentially truncates them. +# Keyed by tool_name → full output string. Consumed (popped) by the +# response adapter when it builds StreamToolOutputAvailable. +_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar( + "pending_tool_outputs", default=None # type: ignore[arg-type] +) + +# Callback type for delegating long-running tools to the non-SDK infrastructure. +# Args: (tool_name, arguments, session) → MCP-formatted response dict. +LongRunningCallback = Callable[ + [str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]] +] + +# ContextVar so the service layer can inject the callback per-request. +_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar( + "long_running_callback", default=None +) + + +def set_execution_context( + user_id: str | None, + session: ChatSession, + long_running_callback: LongRunningCallback | None = None, +) -> None: + """Set the execution context for tool calls. + + This must be called before streaming begins to ensure tools have access + to user_id and session information. + + Args: + user_id: Current user's ID. + session: Current chat session. + long_running_callback: Optional callback to delegate long-running tools + to the non-SDK background infrastructure (stream_registry + Redis). + """ + _current_user_id.set(user_id) + _current_session.set(session) + _pending_tool_outputs.set({}) + _long_running_callback.set(long_running_callback) + + +def get_execution_context() -> tuple[str | None, ChatSession | None]: + """Get the current execution context.""" + return ( + _current_user_id.get(), + _current_session.get(), + ) + + +def pop_pending_tool_output(tool_name: str) -> str | None: + """Pop and return the stashed full output for *tool_name*. + + The SDK CLI may truncate large tool results (writing them to disk and + replacing the content with a file reference). This stash keeps the + original MCP output so the response adapter can forward it to the + frontend for proper widget rendering. + + Returns ``None`` if nothing was stashed for *tool_name*. + """ + pending = _pending_tool_outputs.get(None) + if pending is None: + return None + return pending.pop(tool_name, None) + + +async def _execute_tool_sync( + base_tool: BaseTool, + user_id: str | None, + session: ChatSession, + args: dict[str, Any], +) -> dict[str, Any]: + """Execute a tool synchronously and return MCP-formatted response.""" + effective_id = f"sdk-{uuid.uuid4().hex[:12]}" + result = await base_tool.execute( + user_id=user_id, + session=session, + tool_call_id=effective_id, + **args, + ) + + text = ( + result.output if isinstance(result.output, str) else json.dumps(result.output) + ) + + # Stash the full output before the SDK potentially truncates it. 
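+    # e.g. when run_block returns a large JSON payload the CLI may swap the tool
+    # result for a file reference; the response adapter later calls
+    # pop_pending_tool_output("run_block") to recover the full payload
+    # ("run_block" is just an illustrative name from TOOL_REGISTRY).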
+ pending = _pending_tool_outputs.get(None) + if pending is not None: + pending[base_tool.name] = text + + return { + "content": [{"type": "text", "text": text}], + "isError": not result.success, + } + + +def _mcp_error(message: str) -> dict[str, Any]: + return { + "content": [ + {"type": "text", "text": json.dumps({"error": message, "type": "error"})} + ], + "isError": True, + } + + +def create_tool_handler(base_tool: BaseTool): + """Create an async handler function for a BaseTool. + + This wraps the existing BaseTool._execute method to be compatible + with the Claude Agent SDK MCP tool format. + + Long-running tools (``is_long_running=True``) are delegated to the + non-SDK background infrastructure via a callback set in the execution + context. The callback persists the operation in Redis (stream_registry) + so results survive page refreshes and pod restarts. + """ + + async def tool_handler(args: dict[str, Any]) -> dict[str, Any]: + """Execute the wrapped tool and return MCP-formatted response.""" + user_id, session = get_execution_context() + + if session is None: + return _mcp_error("No session context available") + + # --- Long-running: delegate to non-SDK background infrastructure --- + if base_tool.is_long_running: + callback = _long_running_callback.get(None) + if callback: + try: + return await callback(base_tool.name, args, session) + except Exception as e: + logger.error( + f"Long-running callback failed for {base_tool.name}: {e}", + exc_info=True, + ) + return _mcp_error(f"Failed to start {base_tool.name}: {e}") + # No callback — fall through to synchronous execution + logger.warning( + f"[SDK] No long-running callback for {base_tool.name}, " + f"executing synchronously (may block)" + ) + + # --- Normal (fast) tool: execute synchronously --- + try: + return await _execute_tool_sync(base_tool, user_id, session, args) + except Exception as e: + logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True) + return _mcp_error(f"Failed to execute {base_tool.name}: {e}") + + return tool_handler + + +def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]: + """Build a JSON Schema input schema for a tool.""" + return { + "type": "object", + "properties": base_tool.parameters.get("properties", {}), + "required": base_tool.parameters.get("required", []), + } + + +async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]: + """Read a file with optional offset/limit. Restricted to SDK working directory. + + After reading, the file is deleted to prevent accumulation in long-running pods. + """ + file_path = args.get("file_path", "") + offset = args.get("offset", 0) + limit = args.get("limit", 2000) + + # Security: only allow reads under ~/.claude/projects/**/tool-results/ + real_path = os.path.realpath(file_path) + if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path: + return { + "content": [{"type": "text", "text": f"Access denied: {file_path}"}], + "isError": True, + } + + try: + with open(real_path) as f: + selected = list(itertools.islice(f, offset, offset + limit)) + content = "".join(selected) + # Cleanup happens in _cleanup_sdk_tool_results after session ends; + # don't delete here — the SDK may read in multiple chunks. 
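+            # e.g. the CLI may issue Read(file_path=..., limit=2000) followed by
+            # Read(file_path=..., offset=2000) to page through the same result file.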
+ return {"content": [{"type": "text", "text": content}], "isError": False} + except FileNotFoundError: + return { + "content": [{"type": "text", "text": f"File not found: {file_path}"}], + "isError": True, + } + except Exception as e: + return { + "content": [{"type": "text", "text": f"Error reading file: {e}"}], + "isError": True, + } + + +_READ_TOOL_NAME = "Read" +_READ_TOOL_DESCRIPTION = ( + "Read a file from the local filesystem. " + "Use offset and limit to read specific line ranges for large files." +) +_READ_TOOL_SCHEMA = { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to read", + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (0-indexed). Default: 0", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read. Default: 2000", + }, + }, + "required": ["file_path"], +} + + +# Create the MCP server configuration +def create_copilot_mcp_server(): + """Create an in-process MCP server configuration for CoPilot tools. + + This can be passed to ClaudeAgentOptions.mcp_servers. + + Note: The actual SDK MCP server creation depends on the claude-agent-sdk + package being available. This function returns the configuration that + can be used with the SDK. + """ + try: + from claude_agent_sdk import create_sdk_mcp_server, tool + + # Create decorated tool functions + sdk_tools = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + handler = create_tool_handler(base_tool) + decorated = tool( + tool_name, + base_tool.description, + _build_input_schema(base_tool), + )(handler) + sdk_tools.append(decorated) + + # Add the Read tool so the SDK can read back oversized tool results + read_tool = tool( + _READ_TOOL_NAME, + _READ_TOOL_DESCRIPTION, + _READ_TOOL_SCHEMA, + )(_read_file_handler) + sdk_tools.append(read_tool) + + server = create_sdk_mcp_server( + name=MCP_SERVER_NAME, + version="1.0.0", + tools=sdk_tools, + ) + + return server + + except ImportError: + # Let ImportError propagate so service.py handles the fallback + raise + + +# SDK built-in tools allowed within the workspace directory. +# Security hooks validate that file paths stay within sdk_cwd. +# Bash is NOT included — use the sandboxed MCP bash_exec tool instead, +# which provides kernel-level network isolation via unshare --net. +# Task allows spawning sub-agents (rate-limited by security hooks). +# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk. +_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"] + +# SDK built-in tools that must be explicitly blocked. +# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level +# network isolation (unshare --net) instead. +# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.). +# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead. +SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"] + +# Tools that are blocked entirely in security hooks (defence-in-depth). +# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms. +BLOCKED_TOOLS = { + *SDK_DISALLOWED_TOOLS, + "bash", + "shell", + "exec", + "terminal", + "command", +} + +# Tools allowed only when their path argument stays within the SDK workspace. +# The SDK uses these to handle oversized tool results (writes to tool-results/ +# files, then reads them back) and for workspace file operations. 
+WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"} + +# Dangerous patterns in tool inputs +DANGEROUS_PATTERNS = [ + r"sudo", + r"rm\s+-rf", + r"dd\s+if=", + r"/etc/passwd", + r"/etc/shadow", + r"chmod\s+777", + r"curl\s+.*\|.*sh", + r"wget\s+.*\|.*sh", + r"eval\s*\(", + r"exec\s*\(", + r"__import__", + r"os\.system", + r"subprocess", +] + +# List of tool names for allowed_tools configuration +# Include MCP tools, the MCP Read tool for oversized results, +# and SDK built-in file tools for workspace operations. +COPILOT_TOOL_NAMES = [ + *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}", + *_SDK_BUILTIN_TOOLS, +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py new file mode 100644 index 0000000000..aaa5609227 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py @@ -0,0 +1,356 @@ +"""JSONL transcript management for stateless multi-turn resume. + +The Claude Code CLI persists conversations as JSONL files (one JSON object per +line). When the SDK's ``Stop`` hook fires we read this file, strip bloat +(progress entries, metadata), and upload the result to bucket storage. On the +next turn we download the transcript, write it to a temp file, and pass +``--resume`` so the CLI can reconstruct the full conversation. + +Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local +filesystem for self-hosted) — no DB column needed. +""" + +import json +import logging +import os +import re + +logger = logging.getLogger(__name__) + +# UUIDs are hex + hyphens; strip everything else to prevent path injection. +_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]") + +# Entry types that can be safely removed from the transcript without breaking +# the parentUuid conversation tree that ``--resume`` relies on. +# - progress: UI progress ticks, no message content (avg 97KB for agent_progress) +# - file-history-snapshot: undo tracking metadata +# - queue-operation: internal queue bookkeeping +# - summary: session summaries +# - pr-link: PR link metadata +STRIPPABLE_TYPES = frozenset( + {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"} +) + +# Workspace storage constants — deterministic path from session_id. +TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" + + +# --------------------------------------------------------------------------- +# Progress stripping +# --------------------------------------------------------------------------- + + +def strip_progress_entries(content: str) -> str: + """Remove progress/metadata entries from a JSONL transcript. + + Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents + any remaining child entries so the ``parentUuid`` chain stays intact. + Typically reduces transcript size by ~30%. 
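+    Illustrative example (entries abbreviated, not a real transcript)::
+
+        {"uuid": "a", "parentUuid": null, "type": "user", ...}
+        {"uuid": "b", "parentUuid": "a",  "type": "progress", ...}   <- removed
+        {"uuid": "c", "parentUuid": "b",  "type": "assistant", ...}  <- kept
+
+    After stripping, entry "c" is reparented to "a" so its parent chain still
+    resolves to a surviving ancestor.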
+ """ + lines = content.strip().split("\n") + + entries: list[dict] = [] + for line in lines: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + # Keep unparseable lines as-is (safety) + entries.append({"_raw": line}) + + stripped_uuids: set[str] = set() + uuid_to_parent: dict[str, str] = {} + kept: list[dict] = [] + + for entry in entries: + if "_raw" in entry: + kept.append(entry) + continue + uid = entry.get("uuid", "") + parent = entry.get("parentUuid", "") + entry_type = entry.get("type", "") + + if uid: + uuid_to_parent[uid] = parent + + if entry_type in STRIPPABLE_TYPES: + if uid: + stripped_uuids.add(uid) + else: + kept.append(entry) + + # Reparent: walk up chain through stripped entries to find surviving ancestor + for entry in kept: + if "_raw" in entry: + continue + parent = entry.get("parentUuid", "") + original_parent = parent + while parent in stripped_uuids: + parent = uuid_to_parent.get(parent, "") + if parent != original_parent: + entry["parentUuid"] = parent + + result_lines: list[str] = [] + for entry in kept: + if "_raw" in entry: + result_lines.append(entry["_raw"]) + else: + result_lines.append(json.dumps(entry, separators=(",", ":"))) + + return "\n".join(result_lines) + "\n" + + +# --------------------------------------------------------------------------- +# Local file I/O (read from CLI's JSONL, write temp file for --resume) +# --------------------------------------------------------------------------- + + +def read_transcript_file(transcript_path: str) -> str | None: + """Read a JSONL transcript file from disk. + + Returns the raw JSONL content, or ``None`` if the file is missing, empty, + or only contains metadata (≤2 lines with no conversation messages). + """ + if not transcript_path or not os.path.isfile(transcript_path): + logger.debug(f"[Transcript] File not found: {transcript_path}") + return None + + try: + with open(transcript_path) as f: + content = f.read() + + if not content.strip(): + logger.debug(f"[Transcript] Empty file: {transcript_path}") + return None + + lines = content.strip().split("\n") + if len(lines) < 3: + # Raw files with ≤2 lines are metadata-only + # (queue-operation + file-history-snapshot, no conversation). + logger.debug( + f"[Transcript] Too few lines ({len(lines)}): {transcript_path}" + ) + return None + + # Quick structural validation — parse first and last lines. + json.loads(lines[0]) + json.loads(lines[-1]) + + logger.info( + f"[Transcript] Read {len(lines)} lines, " + f"{len(content)} bytes from {transcript_path}" + ) + return content + + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}") + return None + + +def _sanitize_id(raw_id: str, max_len: int = 36) -> str: + """Sanitize an ID for safe use in file paths. + + Session/user IDs are expected to be UUIDs (hex + hyphens). Strip + everything else and truncate to *max_len* so the result cannot introduce + path separators or other special characters. + """ + cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len] + return cleaned or "unknown" + + +_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") + + +def write_transcript_to_tempfile( + transcript_content: str, + session_id: str, + cwd: str, +) -> str | None: + """Write JSONL transcript to a temp file inside *cwd* for ``--resume``. + + The file lives in the session working directory so it is cleaned up + automatically when the session ends. + + Returns the absolute path to the file, or ``None`` on failure. 
+ """ + # Validate cwd is under the expected sandbox prefix (CodeQL sanitizer). + real_cwd = os.path.realpath(cwd) + if not real_cwd.startswith(_SAFE_CWD_PREFIX): + logger.warning(f"[Transcript] cwd outside sandbox: {cwd}") + return None + + try: + os.makedirs(real_cwd, exist_ok=True) + safe_id = _sanitize_id(session_id, max_len=8) + jsonl_path = os.path.realpath( + os.path.join(real_cwd, f"transcript-{safe_id}.jsonl") + ) + if not jsonl_path.startswith(real_cwd): + logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}") + return None + + with open(jsonl_path, "w") as f: + f.write(transcript_content) + + logger.info(f"[Transcript] Wrote resume file: {jsonl_path}") + return jsonl_path + + except OSError as e: + logger.warning(f"[Transcript] Failed to write resume file: {e}") + return None + + +def validate_transcript(content: str | None) -> bool: + """Check that a transcript has actual conversation messages. + + A valid transcript for resume needs at least one user message and one + assistant message (not just queue-operation / file-history-snapshot + metadata). + """ + if not content or not content.strip(): + return False + + lines = content.strip().split("\n") + if len(lines) < 2: + return False + + has_user = False + has_assistant = False + + for line in lines: + try: + entry = json.loads(line) + msg_type = entry.get("type") + if msg_type == "user": + has_user = True + elif msg_type == "assistant": + has_assistant = True + except json.JSONDecodeError: + return False + + return has_user and has_assistant + + +# --------------------------------------------------------------------------- +# Bucket storage (GCS / local via WorkspaceStorageBackend) +# --------------------------------------------------------------------------- + + +def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for a session's transcript. + + Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` + IDs are sanitized to hex+hyphen to prevent path traversal. + """ + return ( + TRANSCRIPT_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.jsonl", + ) + + +def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: + """Build the full storage path string that ``retrieve()`` expects. + + ``store()`` returns a path like ``gcs://bucket/workspaces/...`` or + ``local://workspace_id/file_id/filename``. Since we use deterministic + arguments we can reconstruct the same path for download/delete without + having stored the return value. + """ + from backend.util.workspace_storage import GCSWorkspaceStorage + + wid, fid, fname = _storage_path_parts(user_id, session_id) + + if isinstance(backend, GCSWorkspaceStorage): + blob = f"workspaces/{wid}/{fid}/{fname}" + return f"gcs://{backend.bucket_name}/{blob}" + else: + # LocalWorkspaceStorage returns local://{relative_path} + return f"local://{wid}/{fid}/{fname}" + + +async def upload_transcript(user_id: str, session_id: str, content: str) -> None: + """Strip progress entries and upload transcript to bucket storage. + + Safety: only overwrites when the new (stripped) transcript is larger than + what is already stored. Since JSONL is append-only, the latest transcript + is always the longest. This prevents a slow/stale background task from + clobbering a newer upload from a concurrent turn. 
+ """ + from backend.util.workspace_storage import get_workspace_storage + + stripped = strip_progress_entries(content) + if not validate_transcript(stripped): + logger.warning( + f"[Transcript] Skipping upload — stripped content is not a valid " + f"transcript for session {session_id}" + ) + return + + storage = await get_workspace_storage() + wid, fid, fname = _storage_path_parts(user_id, session_id) + encoded = stripped.encode("utf-8") + new_size = len(encoded) + + # Check existing transcript size to avoid overwriting newer with older + path = _build_storage_path(user_id, session_id, storage) + try: + existing = await storage.retrieve(path) + if len(existing) >= new_size: + logger.info( + f"[Transcript] Skipping upload — existing transcript " + f"({len(existing)}B) >= new ({new_size}B) for session " + f"{session_id}" + ) + return + except (FileNotFoundError, Exception): + pass # No existing transcript or retrieval error — proceed with upload + + await storage.store( + workspace_id=wid, + file_id=fid, + filename=fname, + content=encoded, + ) + logger.info( + f"[Transcript] Uploaded {new_size} bytes " + f"(stripped from {len(content)}) for session {session_id}" + ) + + +async def download_transcript(user_id: str, session_id: str) -> str | None: + """Download transcript from bucket storage. + + Returns the JSONL content string, or ``None`` if not found. + """ + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + data = await storage.retrieve(path) + content = data.decode("utf-8") + logger.info( + f"[Transcript] Downloaded {len(content)} bytes for session {session_id}" + ) + return content + except FileNotFoundError: + logger.debug(f"[Transcript] No transcript in storage for {session_id}") + return None + except Exception as e: + logger.warning(f"[Transcript] Failed to download transcript: {e}") + return None + + +async def delete_transcript(user_id: str, session_id: str) -> None: + """Delete transcript from bucket storage (e.g. after resume failure).""" + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + await storage.delete(path) + logger.info(f"[Transcript] Deleted transcript for session {session_id}") + except Exception as e: + logger.warning(f"[Transcript] Failed to delete transcript: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 193566ea01..cb5591e6d0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str: return DEFAULT_SYSTEM_PROMPT.format(users_information=context) -async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: +async def _build_system_prompt( + user_id: str | None, has_conversation_history: bool = False +) -> tuple[str, Any]: """Build the full system prompt including business understanding if available. Args: - user_id: The user ID for fetching business understanding - If "default" and this is the user's first session, will use "onboarding" instead. + user_id: The user ID for fetching business understanding. + has_conversation_history: Whether there's existing conversation history. 
+ If True, we don't tell the model to greet/introduce (since they're + already in a conversation). Returns: Tuple of (compiled prompt string, business understanding object) @@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: if understanding: context = format_understanding_for_prompt(understanding) + elif has_conversation_history: + context = "No prior understanding saved yet. Continue the existing conversation naturally." else: context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" @@ -374,7 +380,6 @@ async def stream_chat_completion( Raises: NotFoundError: If session_id is invalid - ValueError: If max_context_messages is exceeded """ completion_start = time.monotonic() @@ -459,8 +464,9 @@ async def stream_chat_completion( # Generate title for new sessions on first user message (non-blocking) # Check: is_user_message, no title yet, and this is the first user message - if is_user_message and message and not session.title: - user_messages = [m for m in session.messages if m.role == "user"] + user_messages = [m for m in session.messages if m.role == "user"] + first_user_msg = message or (user_messages[0].content if user_messages else None) + if is_user_message and first_user_msg and not session.title: if len(user_messages) == 1: # First user message - generate title in background import asyncio @@ -468,7 +474,7 @@ async def stream_chat_completion( # Capture only the values we need (not the session object) to avoid # stale data issues when the main flow modifies the session captured_session_id = session_id - captured_message = message + captured_message = first_user_msg captured_user_id = user_id async def _update_title(): @@ -1237,7 +1243,7 @@ async def _stream_chat_chunks( total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 logger.info( - f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; " + f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; " f"session={session.session_id}, user={session.user_id}", extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, ) @@ -1245,6 +1251,7 @@ async def _stream_chat_chunks( return except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 # Calculate delay with exponential backoff @@ -1260,12 +1267,27 @@ async def _stream_chat_chunks( continue # Retry the stream else: # Non-retryable error or max retries exceeded - logger.error( - f"Error in stream (not retrying): {e!s}", - exc_info=True, + _log_api_error( + error=e, + context="stream (not retrying)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=retry_count, ) error_code = None error_text = str(e) + + error_details = _extract_api_error_details(e) + if error_details.get("response_body"): + body = error_details["response_body"] + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict) and err.get("message"): + error_text = err["message"] + elif body.get("message"): + error_text = body["message"] + if _is_region_blocked_error(e): error_code = "MODEL_NOT_AVAILABLE_REGION" error_text = ( @@ -1282,9 +1304,13 @@ async def _stream_chat_chunks( # If we exit the retry loop without returning, it means we exhausted retries if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded. 
Last error: {last_error!s}", - exc_info=True, + _log_api_error( + error=last_error, + context=f"stream (max retries {MAX_RETRIES} exceeded)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=MAX_RETRIES, ) yield StreamError(errorText=f"Max retries exceeded: {last_error!s}") yield StreamFinish() @@ -1857,6 +1883,7 @@ async def _generate_llm_continuation( break # Success, exit retry loop except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 delay = min( @@ -1870,17 +1897,25 @@ async def _generate_llm_continuation( await asyncio.sleep(delay) continue else: - # Non-retryable error - log and exit gracefully - logger.error( - f"Non-retryable error in LLM continuation: {e!s}", - exc_info=True, + # Non-retryable error - log details and exit gracefully + _log_api_error( + error=e, + context="LLM continuation (not retrying)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=retry_count, ) return if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. " - f"Last error: {last_error!s}" + _log_api_error( + error=last_error, + context=f"LLM continuation (max retries {MAX_RETRIES} exceeded)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=MAX_RETRIES, ) return @@ -1920,6 +1955,91 @@ async def _generate_llm_continuation( logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) +def _log_api_error( + error: Exception, + context: str, + session_id: str | None = None, + message_count: int | None = None, + model: str | None = None, + retry_count: int = 0, +) -> None: + """Log detailed API error information for debugging.""" + details = _extract_api_error_details(error) + details["context"] = context + details["session_id"] = session_id + details["message_count"] = message_count + details["model"] = model + details["retry_count"] = retry_count + + if isinstance(error, RateLimitError): + logger.warning(f"Rate limit error in {context}: {details}", exc_info=error) + elif isinstance(error, APIConnectionError): + logger.warning(f"API connection error in {context}: {details}", exc_info=error) + elif isinstance(error, APIStatusError) and error.status_code >= 500: + logger.error(f"API server error (5xx) in {context}: {details}", exc_info=error) + else: + logger.error(f"API error in {context}: {details}", exc_info=error) + + +def _extract_api_error_details(error: Exception) -> dict[str, Any]: + """Extract detailed information from OpenAI/OpenRouter API errors.""" + error_msg = str(error) + details: dict[str, Any] = { + "error_type": type(error).__name__, + "error_message": error_msg[:500] + "..." 
if len(error_msg) > 500 else error_msg, + } + + if hasattr(error, "code"): + details["code"] = getattr(error, "code", None) + if hasattr(error, "param"): + details["param"] = getattr(error, "param", None) + + if isinstance(error, APIStatusError): + details["status_code"] = error.status_code + details["request_id"] = getattr(error, "request_id", None) + + if hasattr(error, "body") and error.body: + details["response_body"] = _sanitize_error_body(error.body) + + if hasattr(error, "response") and error.response: + headers = error.response.headers + details["openrouter_provider"] = headers.get("x-openrouter-provider") + details["openrouter_model"] = headers.get("x-openrouter-model") + details["retry_after"] = headers.get("retry-after") + details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining") + + return details + + +def _sanitize_error_body( + body: Any, max_length: int = 2000 +) -> dict[str, Any] | str | None: + """Extract only safe fields from error response body to avoid logging sensitive data.""" + if not isinstance(body, dict): + # Non-dict bodies (e.g., HTML error pages) - return truncated string + if body is not None: + body_str = str(body) + if len(body_str) > max_length: + return body_str[:max_length] + "...[truncated]" + return body_str + return None + + safe_fields = ("message", "type", "code", "param", "error") + sanitized: dict[str, Any] = {} + + for field in safe_fields: + if field in body: + value = body[field] + if field == "error" and isinstance(value, dict): + sanitized[field] = _sanitize_error_body(value, max_length) + elif isinstance(value, str) and len(value) > max_length: + sanitized[field] = value[:max_length] + "...[truncated]" + else: + sanitized[field] = value + + return sanitized if sanitized else None + + async def _generate_llm_continuation_with_streaming( session_id: str, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/service_test.py b/autogpt_platform/backend/backend/api/features/chat/service_test.py index 70f27af14f..b2fc82b790 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/service_test.py @@ -1,3 +1,4 @@ +import asyncio import logging from os import getenv @@ -11,6 +12,8 @@ from .response_model import ( StreamTextDelta, StreamToolOutputAvailable, ) +from .sdk import service as sdk_service +from .sdk.transcript import download_transcript logger = logging.getLogger(__name__) @@ -80,3 +83,96 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user session = await get_chat_session(session.session_id) assert session, "Session not found" assert session.usage, "Usage is empty" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): + """Test that the SDK --resume path captures and uses transcripts across turns. + + Turn 1: Send a message containing a unique keyword. + Turn 2: Ask the model to recall that keyword — proving the transcript was + persisted and restored via --resume. 
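+    The transcript upload runs as a background task, so after turn 1 the test
+    polls download_transcript for up to ~5 seconds before asserting it exists.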
+ """ + api_key: str | None = getenv("OPEN_ROUTER_API_KEY") + if not api_key: + return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test") + + from .config import ChatConfig + + cfg = ChatConfig() + if not cfg.claude_agent_use_resume: + return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test") + + session = await create_chat_session(test_user_id) + session = await upsert_chat_session(session) + + # --- Turn 1: send a message with a unique keyword --- + keyword = "ZEPHYR42" + turn1_msg = ( + f"Please remember this special keyword: {keyword}. " + "Just confirm you've noted it, keep your response brief." + ) + turn1_text = "" + turn1_errors: list[str] = [] + turn1_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn1_msg, + user_id=test_user_id, + ): + if isinstance(chunk, StreamTextDelta): + turn1_text += chunk.delta + elif isinstance(chunk, StreamError): + turn1_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn1_ended = True + + assert turn1_ended, "Turn 1 did not finish" + assert not turn1_errors, f"Turn 1 errors: {turn1_errors}" + assert turn1_text, "Turn 1 produced no text" + + # Wait for background upload task to complete (retry up to 5s) + transcript = None + for _ in range(10): + await asyncio.sleep(0.5) + transcript = await download_transcript(test_user_id, session.session_id) + if transcript: + break + assert transcript, ( + "Transcript was not uploaded to bucket after turn 1 — " + "Stop hook may not have fired or transcript was too small" + ) + logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes") + + # Reload session for turn 2 + session = await get_chat_session(session.session_id, test_user_id) + assert session, "Session not found after turn 1" + + # --- Turn 2: ask model to recall the keyword --- + turn2_msg = "What was the special keyword I asked you to remember?" + turn2_text = "" + turn2_errors: list[str] = [] + turn2_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn2_msg, + user_id=test_user_id, + session=session, + ): + if isinstance(chunk, StreamTextDelta): + turn2_text += chunk.delta + elif isinstance(chunk, StreamError): + turn2_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn2_ended = True + + assert turn2_ended, "Turn 2 did not finish" + assert not turn2_errors, f"Turn 2 errors: {turn2_errors}" + assert turn2_text, "Turn 2 produced no text" + assert keyword in turn2_text, ( + f"Model did not recall keyword '{keyword}' in turn 2. 
" + f"Response: {turn2_text[:200]}" + ) + logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}") diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index abc34b1fc9..671aefc7ba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -814,6 +814,28 @@ async def get_active_task_for_session( if task_user_id and user_id != task_user_id: continue + # Auto-expire stale tasks that exceeded stream_timeout + created_at_str = meta.get("created_at", "") + if created_at_str: + try: + created_at = datetime.fromisoformat(created_at_str) + age_seconds = ( + datetime.now(timezone.utc) - created_at + ).total_seconds() + if age_seconds > config.stream_timeout: + logger.warning( + f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... " + f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)" + ) + await mark_task_completed(task_id, "failed") + continue + except (ValueError, TypeError): + pass + + logger.info( + f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..." + ) + # Get the last message ID from Redis Stream stream_key = _get_task_stream_key(task_id) last_id = "0-0" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index dcbc35ef37..1ab4f720bb 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -9,9 +9,12 @@ from backend.api.features.chat.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool +from .bash_exec import BashExecTool +from .check_operation_status import CheckOperationStatusTool from .create_agent import CreateAgentTool from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool from .find_agent import FindAgentTool from .find_block import FindBlockTool from .find_library_agent import FindLibraryAgentTool @@ -19,6 +22,7 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .web_fetch import WebFetchTool from .workspace_files import ( DeleteWorkspaceFileTool, ListWorkspaceFilesTool, @@ -43,8 +47,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "run_agent": RunAgentTool(), "run_block": RunBlockTool(), "view_agent_output": AgentOutputTool(), + "check_operation_status": CheckOperationStatusTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Web fetch for safe URL retrieval + "web_fetch": WebFetchTool(), + # Sandboxed code execution (bubblewrap) + "bash_exec": BashExecTool(), + # Persistent workspace tools (cloud storage, survives across sessions) + # Feature request tools + "search_feature_requests": SearchFeatureRequestsTool(), + "create_feature_request": CreateFeatureRequestTool(), # Workspace tools for CoPilot file operations "list_workspace_files": ListWorkspaceFilesTool(), "read_workspace_file": ReadWorkspaceFileTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py new file mode 
100644 index 0000000000..da9d8bf3fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py @@ -0,0 +1,131 @@ +"""Bash execution tool — run shell commands in a bubblewrap sandbox. + +Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.). +Safety comes from OS-level isolation (bubblewrap): only system dirs visible +read-only, writable workspace only, clean env, no network. + +Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not +available (e.g. macOS development). +""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + BashExecResponse, + ErrorResponse, + ToolResponseBase, +) +from backend.api.features.chat.tools.sandbox import ( + get_workspace_dir, + has_full_sandbox, + run_sandboxed, +) + +logger = logging.getLogger(__name__) + + +class BashExecTool(BaseTool): + """Execute Bash commands in a bubblewrap sandbox.""" + + @property + def name(self) -> str: + return "bash_exec" + + @property + def description(self) -> str: + if not has_full_sandbox(): + return ( + "Bash execution is DISABLED — bubblewrap sandbox is not " + "available on this platform. Do not call this tool." + ) + return ( + "Execute a Bash command or script in a bubblewrap sandbox. " + "Full Bash scripting is supported (loops, conditionals, pipes, " + "functions, etc.). " + "The sandbox shares the same working directory as the SDK Read/Write " + "tools — files created by either are accessible to both. " + "SECURITY: Only system directories (/usr, /bin, /lib, /etc) are " + "visible read-only, the per-session workspace is the only writable " + "path, environment variables are wiped (no secrets), all network " + "access is blocked at the kernel level, and resource limits are " + "enforced (max 64 processes, 512MB memory, 50MB file size). " + "Application code, configs, and other directories are NOT accessible. " + "To fetch web content, use the web_fetch tool instead. " + "Execution is killed after the timeout (default 30s, max 120s). " + "Returns stdout and stderr. " + "Useful for file manipulation, data processing with Unix tools " + "(grep, awk, sed, jq, etc.), and running shell scripts." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command or script to execute.", + }, + "timeout": { + "type": "integer", + "description": ( + "Max execution time in seconds (default 30, max 120)." 
+ ), + "default": 30, + }, + }, + "required": ["command"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + session_id = session.session_id if session else None + + if not has_full_sandbox(): + return ErrorResponse( + message="bash_exec requires bubblewrap sandbox (Linux only).", + error="sandbox_unavailable", + session_id=session_id, + ) + + command: str = (kwargs.get("command") or "").strip() + timeout: int = kwargs.get("timeout", 30) + + if not command: + return ErrorResponse( + message="No command provided.", + error="empty_command", + session_id=session_id, + ) + + workspace = get_workspace_dir(session_id or "default") + + stdout, stderr, exit_code, timed_out = await run_sandboxed( + command=["bash", "-c", command], + cwd=workspace, + timeout=timeout, + ) + + return BashExecResponse( + message=( + "Execution timed out" + if timed_out + else f"Command executed (exit {exit_code})" + ), + stdout=stdout, + stderr=stderr, + exit_code=exit_code, + timed_out=timed_out, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py new file mode 100644 index 0000000000..b8ec770fd0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py @@ -0,0 +1,127 @@ +"""CheckOperationStatusTool — query the status of a long-running operation.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ResponseType, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class OperationStatusResponse(ToolResponseBase): + """Response for check_operation_status tool.""" + + type: ResponseType = ResponseType.OPERATION_STATUS + task_id: str + operation_id: str + status: str # "running", "completed", "failed" + tool_name: str | None = None + message: str = "" + + +class CheckOperationStatusTool(BaseTool): + """Check the status of a long-running operation (create_agent, edit_agent, etc.). + + The CoPilot uses this tool to report back to the user whether an + operation that was started earlier has completed, failed, or is still + running. + """ + + @property + def name(self) -> str: + return "check_operation_status" + + @property + def description(self) -> str: + return ( + "Check the current status of a long-running operation such as " + "create_agent or edit_agent. Accepts either an operation_id or " + "task_id from a previous operation_started response. " + "Returns the current status: running, completed, or failed." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "operation_id": { + "type": "string", + "description": ( + "The operation_id from an operation_started response." + ), + }, + "task_id": { + "type": "string", + "description": ( + "The task_id from an operation_started response. " + "Used as fallback if operation_id is not provided." 
+ ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + from backend.api.features.chat import stream_registry + + operation_id = (kwargs.get("operation_id") or "").strip() + task_id = (kwargs.get("task_id") or "").strip() + + if not operation_id and not task_id: + return ErrorResponse( + message="Please provide an operation_id or task_id.", + error="missing_parameter", + ) + + task = None + if operation_id: + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None and task_id: + task = await stream_registry.get_task(task_id) + + if task is None: + # Task not in Redis — it may have already expired (TTL). + # Check conversation history for the result instead. + return ErrorResponse( + message=( + "Operation not found — it may have already completed and " + "expired from the status tracker. Check the conversation " + "history for the result." + ), + error="not_found", + ) + + status_messages = { + "running": ( + f"The {task.tool_name or 'operation'} is still running. " + "Please wait for it to complete." + ), + "completed": ( + f"The {task.tool_name or 'operation'} has completed successfully." + ), + "failed": f"The {task.tool_name or 'operation'} has failed.", + } + + return OperationStatusResponse( + task_id=task.task_id, + operation_id=task.operation_id, + status=task.status, + tool_name=task.tool_name, + message=status_messages.get(task.status, f"Status: {task.status}"), + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py new file mode 100644 index 0000000000..95f1eb1fbe --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py @@ -0,0 +1,448 @@ +"""Feature request tools - search and create feature requests via Linear.""" + +import logging +from typing import Any + +from pydantic import SecretStr + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestInfo, + FeatureRequestSearchResponse, + NoResultsResponse, + ToolResponseBase, +) +from backend.blocks.linear._api import LinearClient +from backend.data.model import APIKeyCredentials +from backend.data.user import get_user_email_by_id +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +MAX_SEARCH_RESULTS = 10 + +# GraphQL queries/mutations +SEARCH_ISSUES_QUERY = """ +query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) { + searchIssues(term: $term, filter: $filter, first: $first) { + nodes { + id + identifier + title + description + } + } +} +""" + +CUSTOMER_UPSERT_MUTATION = """ +mutation CustomerUpsert($input: CustomerUpsertInput!) { + customerUpsert(input: $input) { + success + customer { + id + name + externalIds + } + } +} +""" + +ISSUE_CREATE_MUTATION = """ +mutation IssueCreate($input: IssueCreateInput!) { + issueCreate(input: $input) { + success + issue { + id + identifier + title + url + } + } +} +""" + +CUSTOMER_NEED_CREATE_MUTATION = """ +mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) 
{ + customerNeedCreate(input: $input) { + success + need { + id + body + customer { + id + name + } + issue { + id + identifier + title + url + } + } + } +} +""" + + +_settings: Settings | None = None + + +def _get_settings() -> Settings: + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def _get_linear_config() -> tuple[LinearClient, str, str]: + """Return a configured Linear client, project ID, and team ID. + + Raises RuntimeError if any required setting is missing. + """ + secrets = _get_settings().secrets + if not secrets.linear_api_key: + raise RuntimeError("LINEAR_API_KEY is not configured") + if not secrets.linear_feature_request_project_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured") + if not secrets.linear_feature_request_team_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_TEAM_ID is not configured") + + credentials = APIKeyCredentials( + id="system-linear", + provider="linear", + api_key=SecretStr(secrets.linear_api_key), + title="System Linear API Key", + ) + client = LinearClient(credentials=credentials) + return ( + client, + secrets.linear_feature_request_project_id, + secrets.linear_feature_request_team_id, + ) + + +class SearchFeatureRequestsTool(BaseTool): + """Tool for searching existing feature requests in Linear.""" + + @property + def name(self) -> str: + return "search_feature_requests" + + @property + def description(self) -> str: + return ( + "Search existing feature requests to check if a similar request " + "already exists before creating a new one. Returns matching feature " + "requests with their ID, title, and description." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search term to find matching feature requests.", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + query = kwargs.get("query", "").strip() + session_id = session.session_id if session else None + + if not query: + return ErrorResponse( + message="Please provide a search query.", + error="Missing query parameter", + session_id=session_id, + ) + + try: + client, project_id, _team_id = _get_linear_config() + data = await client.query( + SEARCH_ISSUES_QUERY, + { + "term": query, + "filter": { + "project": {"id": {"eq": project_id}}, + }, + "first": MAX_SEARCH_RESULTS, + }, + ) + + nodes = data.get("searchIssues", {}).get("nodes", []) + + if not nodes: + return NoResultsResponse( + message=f"No feature requests found matching '{query}'.", + suggestions=[ + "Try different keywords", + "Use broader search terms", + "You can create a new feature request if none exists", + ], + session_id=session_id, + ) + + results = [ + FeatureRequestInfo( + id=node["id"], + identifier=node["identifier"], + title=node["title"], + description=node.get("description"), + ) + for node in nodes + ] + + return FeatureRequestSearchResponse( + message=f"Found {len(results)} feature request(s) matching '{query}'.", + results=results, + count=len(results), + query=query, + session_id=session_id, + ) + except Exception as e: + logger.exception("Failed to search feature requests") + return ErrorResponse( + message="Failed to search feature requests.", + error=str(e), + session_id=session_id, + ) + + +class CreateFeatureRequestTool(BaseTool): + """Tool for creating feature 
requests (or adding needs to existing ones).""" + + @property + def name(self) -> str: + return "create_feature_request" + + @property + def description(self) -> str: + return ( + "Create a new feature request or add a customer need to an existing one. " + "Always search first with search_feature_requests to avoid duplicates. " + "If a matching request exists, pass its ID as existing_issue_id to add " + "the user's need to it instead of creating a duplicate." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title for the feature request.", + }, + "description": { + "type": "string", + "description": "Detailed description of what the user wants and why.", + }, + "existing_issue_id": { + "type": "string", + "description": ( + "If adding a need to an existing feature request, " + "provide its Linear issue ID (from search results). " + "Omit to create a new feature request." + ), + }, + }, + "required": ["title", "description"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _find_or_create_customer( + self, client: LinearClient, user_id: str, name: str + ) -> dict: + """Find existing customer by user_id or create a new one via upsert. + + Args: + client: Linear API client. + user_id: Stable external ID used to deduplicate customers. + name: Human-readable display name (e.g. the user's email). + """ + data = await client.mutate( + CUSTOMER_UPSERT_MUTATION, + { + "input": { + "name": name, + "externalId": user_id, + }, + }, + ) + result = data.get("customerUpsert", {}) + if not result.get("success"): + raise RuntimeError(f"Failed to upsert customer: {data}") + return result["customer"] + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + title = kwargs.get("title", "").strip() + description = kwargs.get("description", "").strip() + existing_issue_id = kwargs.get("existing_issue_id") + session_id = session.session_id if session else None + + if not title or not description: + return ErrorResponse( + message="Both title and description are required.", + error="Missing required parameters", + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="Authentication required to create feature requests.", + error="Missing user_id", + session_id=session_id, + ) + + try: + client, project_id, team_id = _get_linear_config() + except Exception as e: + logger.exception("Failed to initialize Linear client") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Resolve a human-readable name (email) for the Linear customer record. + # Fall back to user_id if the lookup fails or returns None. 
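+        # Linear's customerUpsert keys on externalId (the user_id), so repeated
+        # requests from the same user reuse a single customer record regardless of
+        # the display name resolved here.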
+ try: + customer_display_name = await get_user_email_by_id(user_id) or user_id + except Exception: + customer_display_name = user_id + + # Step 1: Find or create customer for this user + try: + customer = await self._find_or_create_customer( + client, user_id, customer_display_name + ) + customer_id = customer["id"] + customer_name = customer["name"] + except Exception as e: + logger.exception("Failed to upsert customer in Linear") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Step 2: Create or reuse issue + issue_id: str | None = None + issue_identifier: str | None = None + if existing_issue_id: + # Add need to existing issue - we still need the issue details for response + is_new_issue = False + issue_id = existing_issue_id + else: + # Create new issue in the feature requests project + try: + data = await client.mutate( + ISSUE_CREATE_MUTATION, + { + "input": { + "title": title, + "description": description, + "teamId": team_id, + "projectId": project_id, + }, + }, + ) + result = data.get("issueCreate", {}) + if not result.get("success"): + return ErrorResponse( + message="Failed to create feature request issue.", + error=str(data), + session_id=session_id, + ) + issue = result["issue"] + issue_id = issue["id"] + issue_identifier = issue.get("identifier") + except Exception as e: + logger.exception("Failed to create feature request issue") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + is_new_issue = True + + # Step 3: Create customer need on the issue + try: + data = await client.mutate( + CUSTOMER_NEED_CREATE_MUTATION, + { + "input": { + "customerId": customer_id, + "issueId": issue_id, + "body": description, + "priority": 0, + }, + }, + ) + need_result = data.get("customerNeedCreate", {}) + if not need_result.get("success"): + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(data), + details=orphaned, + session_id=session_id, + ) + need = need_result["need"] + issue_info = need["issue"] + except Exception as e: + logger.exception("Failed to create customer need") + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(e), + details=orphaned, + session_id=session_id, + ) + + return FeatureRequestCreatedResponse( + message=( + f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'}: " + f"{issue_info['title']}." 
+ ), + issue_id=issue_info["id"], + issue_identifier=issue_info["identifier"], + issue_title=issue_info["title"], + issue_url=issue_info.get("url", ""), + is_new_issue=is_new_issue, + customer_name=customer_name, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py new file mode 100644 index 0000000000..438725368f --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py @@ -0,0 +1,615 @@ +"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.api.features.chat.tools.feature_requests import ( + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +) +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestSearchResponse, + NoResultsResponse, +) + +from ._test_data import make_session + +_TEST_USER_ID = "test-user-feature-requests" +_TEST_USER_EMAIL = "testuser@example.com" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +_FAKE_PROJECT_ID = "test-project-id" +_FAKE_TEAM_ID = "test-team-id" + + +def _mock_linear_config(*, query_return=None, mutate_return=None): + """Return a patched _get_linear_config that yields a mock LinearClient.""" + client = AsyncMock() + if query_return is not None: + client.query.return_value = query_return + if mutate_return is not None: + client.mutate.return_value = mutate_return + return ( + patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID), + ), + client, + ) + + +def _search_response(nodes: list[dict]) -> dict: + return {"searchIssues": {"nodes": nodes}} + + +def _customer_upsert_response( + customer_id: str = "cust-1", name: str = _TEST_USER_EMAIL, success: bool = True +) -> dict: + return { + "customerUpsert": { + "success": success, + "customer": {"id": customer_id, "name": name, "externalIds": [name]}, + } + } + + +def _issue_create_response( + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "issueCreate": { + "success": success, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + } + } + + +def _need_create_response( + need_id: str = "need-1", + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "customerNeedCreate": { + "success": success, + "need": { + "id": need_id, + "body": "description", + "customer": {"id": "cust-1", "name": _TEST_USER_EMAIL}, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + }, + } + } + + +# =========================================================================== +# SearchFeatureRequestsTool +# =========================================================================== + + +class TestSearchFeatureRequestsTool: + """Tests for SearchFeatureRequestsTool._execute.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_successful_search(self): + session = make_session(user_id=_TEST_USER_ID) + nodes = [ + { + "id": "id-1", + 
"identifier": "FR-1", + "title": "Dark mode", + "description": "Add dark mode support", + }, + { + "id": "id-2", + "identifier": "FR-2", + "title": "Dark theme", + "description": None, + }, + ] + patcher, _ = _mock_linear_config(query_return=_search_response(nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="dark mode" + ) + + assert isinstance(resp, FeatureRequestSearchResponse) + assert resp.count == 2 + assert resp.results[0].id == "id-1" + assert resp.results[1].identifier == "FR-2" + assert resp.query == "dark mode" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_results(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, _ = _mock_linear_config(query_return=_search_response([])) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="nonexistent" + ) + + assert isinstance(resp, NoResultsResponse) + assert "nonexistent" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_empty_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session, query=" ") + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "query" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_api_failure(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.query.side_effect = RuntimeError("Linear API down") + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Linear API down" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_malformed_node_returns_error(self): + """A node missing required keys should be caught by the try/except.""" + session = make_session(user_id=_TEST_USER_ID) + # Node missing 'identifier' key + bad_nodes = [{"id": "id-1", "title": "Missing identifier"}] + patcher, _ = _mock_linear_config(query_return=_search_response(bad_nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + +# =========================================================================== +# CreateFeatureRequestTool +# =========================================================================== + + +class TestCreateFeatureRequestTool: + """Tests 
for CreateFeatureRequestTool._execute.""" + + @pytest.fixture(autouse=True) + def _patch_email_lookup(self): + with patch( + "backend.api.features.chat.tools.feature_requests.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TEST_USER_EMAIL, + ): + yield + + # ---- Happy paths ------------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_create_new_issue(self): + """Full happy path: upsert customer -> create issue -> attach need.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + _need_create_response(), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="New Feature", + description="Please add this", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is True + assert resp.issue_identifier == "FR-1" + assert resp.customer_name == _TEST_USER_EMAIL + assert client.mutate.call_count == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_add_need_to_existing_issue(self): + """When existing_issue_id is provided, skip issue creation.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(issue_id="existing-1", identifier="FR-99"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Existing Feature", + description="Me too", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is False + assert resp.issue_id == "existing-1" + # Only 2 mutations: customer upsert + need create (no issue create) + assert client.mutate.call_count == 2 + + # ---- Validation errors ------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_title(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="", + description="some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_description(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Some title", + description="", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_user_id(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=None, + session=session, + title="Some title", + description="Some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "user_id" in resp.error.lower() + + # ---- Linear client init failure ---------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + 
"backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + # ---- Customer upsert failures ------------------------------------------ + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = RuntimeError("Customer API error") + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Customer API error" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.return_value = _customer_upsert_response(success=False) + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_malformed_response(self): + """Customer dict missing 'id' key should be caught.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + # success=True but customer has no 'id' + client.mutate.return_value = { + "customerUpsert": { + "success": True, + "customer": {"name": _TEST_USER_ID}, + } + } + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Issue creation failures ------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Issue create failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Issue create failed" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert "Failed to create feature request issue" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_malformed_response(self): + """issueCreate success=True but missing 'issue' key.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, 
client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + {"issueCreate": {"success": True}}, # no 'issue' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Customer need attachment failures --------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_new_issue(self): + """Need creation fails after new issue was created -> orphaned issue info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-1", identifier="FR-10"), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Need attach failed" in resp.error + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-1" + assert resp.details["issue_identifier"] == "FR-10" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_existing_issue(self): + """Need creation fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_includes_orphaned_info(self): + """customerNeedCreate returns success=False -> includes orphaned issue.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-2", identifier="FR-20"), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-2" + assert resp.details["issue_identifier"] == "FR-20" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_existing_issue_no_details(self): + """customerNeedCreate fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_malformed_response(self): + 
"""need_result missing 'need' key after success=True.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + {"customerNeedCreate": {"success": True}}, # no 'need' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "issue-1" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index 55b1c0d510..c51317cb62 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -146,6 +146,7 @@ class FindBlockTool(BaseTool): id=block_id, name=block.name, description=block.description or "", + categories=[c.value for c in block.categories], ) ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index bd19d590a6..b32f6ca2ce 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -41,6 +41,15 @@ class ResponseType(str, Enum): OPERATION_IN_PROGRESS = "operation_in_progress" # Input validation INPUT_VALIDATION_ERROR = "input_validation_error" + # Web fetch + WEB_FETCH = "web_fetch" + # Code execution + BASH_EXEC = "bash_exec" + # Operation status check + OPERATION_STATUS = "operation_status" + # Feature request types + FEATURE_REQUEST_SEARCH = "feature_request_search" + FEATURE_REQUEST_CREATED = "feature_request_created" # Base response model @@ -335,6 +344,19 @@ class BlockInfoSummary(BaseModel): id: str name: str description: str + categories: list[str] + input_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block inputs", + ) + output_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block outputs", + ) + required_inputs: list[BlockInputFieldInfo] = Field( + default_factory=list, + description="List of input fields for this block", + ) class BlockListResponse(ToolResponseBase): @@ -344,6 +366,10 @@ class BlockListResponse(ToolResponseBase): blocks: list[BlockInfoSummary] count: int query: str + usage_hint: str = Field( + default="To execute a block, call run_block with block_id set to the block's " + "'id' field and input_data containing the fields listed in required_inputs." 
+ ) class BlockDetails(BaseModel): @@ -430,3 +456,55 @@ class AsyncProcessingResponse(ToolResponseBase): status: str = "accepted" # Must be "accepted" for detection operation_id: str | None = None task_id: str | None = None + + +class WebFetchResponse(ToolResponseBase): + """Response for web_fetch tool.""" + + type: ResponseType = ResponseType.WEB_FETCH + url: str + status_code: int + content_type: str + content: str + truncated: bool = False + + +class BashExecResponse(ToolResponseBase): + """Response for bash_exec tool.""" + + type: ResponseType = ResponseType.BASH_EXEC + stdout: str + stderr: str + exit_code: int + timed_out: bool = False + + +# Feature request models +class FeatureRequestInfo(BaseModel): + """Information about a feature request issue.""" + + id: str + identifier: str + title: str + description: str | None = None + + +class FeatureRequestSearchResponse(ToolResponseBase): + """Response for search_feature_requests tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH + results: list[FeatureRequestInfo] + count: int + query: str + + +class FeatureRequestCreatedResponse(ToolResponseBase): + """Response for create_feature_request tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED + issue_id: str + issue_identifier: str + issue_title: str + issue_url: str + is_new_issue: bool # False if added to existing + customer_name: str diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py new file mode 100644 index 0000000000..beb326f909 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py @@ -0,0 +1,265 @@ +"""Sandbox execution utilities for code execution tools. + +Provides filesystem + network isolated command execution using **bubblewrap** +(``bwrap``): whitelist-only filesystem (only system dirs visible read-only), +writable workspace only, clean environment, network blocked. + +Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox` +and refuse to run if bubblewrap is not available. +""" + +import asyncio +import logging +import os +import platform +import shutil + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 30 +_MAX_TIMEOUT = 120 + + +# --------------------------------------------------------------------------- +# Sandbox capability detection (cached at first call) +# --------------------------------------------------------------------------- + +_BWRAP_AVAILABLE: bool | None = None + + +def has_full_sandbox() -> bool: + """Return True if bubblewrap is available (filesystem + network isolation). + + On non-Linux platforms (macOS), always returns False. + """ + global _BWRAP_AVAILABLE + if _BWRAP_AVAILABLE is None: + _BWRAP_AVAILABLE = ( + platform.system() == "Linux" and shutil.which("bwrap") is not None + ) + return _BWRAP_AVAILABLE + + +WORKSPACE_PREFIX = "/tmp/copilot-" + + +def make_session_path(session_id: str) -> str: + """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`. + + Shared by both the SDK working-directory setup and the sandbox tools so + they always resolve to the same directory for a given session. + + Steps: + 1. Strip all characters except ``[A-Za-z0-9-]``. + 2. Construct ``/tmp/copilot-``. + 3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised + sanitizer) to prevent path traversal. + + Raises: + ValueError: If the resulting path escapes the prefix. 
+ """ + import re + + safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) + if not safe_id: + safe_id = "default" + path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}") + if not path.startswith(WORKSPACE_PREFIX): + raise ValueError(f"Session path escaped prefix: {path}") + return path + + +def get_workspace_dir(session_id: str) -> str: + """Get or create the workspace directory for a session. + + Uses :func:`make_session_path` — the same path the SDK uses — so that + bash_exec shares the workspace with the SDK file tools. + """ + workspace = make_session_path(session_id) + os.makedirs(workspace, exist_ok=True) + return workspace + + +# --------------------------------------------------------------------------- +# Bubblewrap command builder +# --------------------------------------------------------------------------- + +# System directories mounted read-only inside the sandbox. +# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible. +_SYSTEM_RO_BINDS = [ + "/usr", # binaries, libraries, Python interpreter + "/etc", # system config: ld.so, locale, passwd, alternatives +] + +# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems. +# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind +# can't create a symlink target, so we detect and use --symlink instead. +# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2. +_COMPAT_PATHS = [ + ("/bin", "usr/bin"), # -> /usr/bin on Debian 13 + ("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13 + ("/lib", "usr/lib"), # -> /usr/lib on Debian 13 + ("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter +] + +# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse. +# Applied via ulimit inside the sandbox before exec'ing the user command. +_RESOURCE_LIMITS = ( + "ulimit -u 64" # max 64 processes (prevents fork bombs) + " -v 524288" # 512 MB virtual memory + " -f 51200" # 50 MB max file size (1024-byte blocks) + " -n 256" # 256 open file descriptors + " 2>/dev/null" +) + + +def _build_bwrap_command( + command: list[str], cwd: str, env: dict[str, str] +) -> list[str]: + """Build a bubblewrap command with strict filesystem + network isolation. + + Security model: + - **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``, + ``/bin``, ``/lib``) are mounted read-only. Application code (``/app``), + home directories, ``/var``, ``/opt``, etc. are NOT accessible at all. + - **Writable workspace only**: the per-session workspace is the sole + writable path. + - **Clean environment**: ``--clearenv`` wipes all inherited env vars. + Only the explicitly-passed safe env vars are set inside the sandbox. + - **Network isolation**: ``--unshare-net`` blocks all network access. + - **Resource limits**: ulimit caps on processes (64), memory (512MB), + file size (50MB), and open FDs (256) to prevent fork bombs and abuse. + - **New session**: prevents terminal control escape. + - **Die with parent**: prevents orphaned sandbox processes. + """ + cmd = [ + "bwrap", + # Create a new user namespace so bwrap can set up sandboxing + # inside unprivileged Docker containers (no CAP_SYS_ADMIN needed). + "--unshare-user", + # Wipe all inherited environment variables (API keys, secrets, etc.) 
+ "--clearenv", + ] + + # Set only the safe env vars inside the sandbox + for key, value in env.items(): + cmd.extend(["--setenv", key, value]) + + # System directories: read-only + for path in _SYSTEM_RO_BINDS: + cmd.extend(["--ro-bind", path, path]) + + # Compat paths: use --symlink when host path is a symlink (Debian 13), + # --ro-bind when it's a real directory (older distros). + for path, symlink_target in _COMPAT_PATHS: + if os.path.islink(path): + cmd.extend(["--symlink", symlink_target, path]) + elif os.path.exists(path): + cmd.extend(["--ro-bind", path, path]) + + # Wrap the user command with resource limits: + # sh -c 'ulimit ...; exec "$@"' -- + # `exec "$@"` replaces the shell so there's no extra process overhead, + # and properly handles arguments with spaces. + limited_command = [ + "sh", + "-c", + f'{_RESOURCE_LIMITS}; exec "$@"', + "--", + *command, + ] + + cmd.extend( + [ + # Fresh virtual filesystems + "--dev", + "/dev", + "--proc", + "/proc", + "--tmpfs", + "/tmp", + # Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs. + # (workspace lives under /tmp/copilot-) + "--bind", + cwd, + cwd, + # Isolation + "--unshare-net", + "--die-with-parent", + "--new-session", + "--chdir", + cwd, + "--", + *limited_command, + ] + ) + + return cmd + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def run_sandboxed( + command: list[str], + cwd: str, + timeout: int = _DEFAULT_TIMEOUT, + env: dict[str, str] | None = None, +) -> tuple[str, str, int, bool]: + """Run a command inside a bubblewrap sandbox. + + Callers **must** check :func:`has_full_sandbox` before calling this + function. If bubblewrap is not available, this function raises + :class:`RuntimeError` rather than running unsandboxed. + + Returns: + (stdout, stderr, exit_code, timed_out) + """ + if not has_full_sandbox(): + raise RuntimeError( + "run_sandboxed() requires bubblewrap but bwrap is not available. " + "Callers must check has_full_sandbox() before calling this function." 
+ ) + + timeout = min(max(timeout, 1), _MAX_TIMEOUT) + + safe_env = { + "PATH": "/usr/local/bin:/usr/bin:/bin", + "HOME": cwd, + "TMPDIR": cwd, + "LANG": "en_US.UTF-8", + "PYTHONDONTWRITEBYTECODE": "1", + "PYTHONIOENCODING": "utf-8", + } + if env: + safe_env.update(env) + + full_command = _build_bwrap_command(command, cwd, safe_env) + + try: + proc = await asyncio.create_subprocess_exec( + *full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=safe_env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + return stdout, stderr, proc.returncode or 0, False + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + return "", f"Execution timed out after {timeout}s", -1, True + + except RuntimeError: + raise + except Exception as e: + return "", f"Sandbox error: {e}", -1, False diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index 80a842bf36..3b2168d09e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -15,6 +15,7 @@ from backend.data.model import ( OAuth2Credentials, ) from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName from backend.util.exceptions import NotFoundError logger = logging.getLogger(__name__) @@ -359,7 +360,7 @@ async def match_user_credentials_to_graph( _, _, ) in aggregated_creds.items(): - # Find first matching credential by provider, type, and scopes + # Find first matching credential by provider, type, scopes, and host/URL matching_cred = next( ( cred @@ -374,6 +375,10 @@ async def match_user_credentials_to_graph( cred.type != "host_scoped" or _credential_is_for_host(cred, credential_requirements) ) + and ( + cred.provider != ProviderName.MCP + or _credential_is_for_mcp_server(cred, credential_requirements) + ) ), None, ) @@ -444,6 +449,22 @@ def _credential_is_for_host( return credential.matches_url(list(requirements.discriminator_values)[0]) +def _credential_is_for_mcp_server( + credential: Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an MCP OAuth credential matches the required server URL.""" + if not requirements.discriminator_values: + return True + + server_url = ( + credential.metadata.get("mcp_server_url") + if isinstance(credential, OAuth2Credentials) + else None + ) + return server_url in requirements.discriminator_values if server_url else False + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py new file mode 100644 index 0000000000..fed7cc11fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py @@ -0,0 +1,151 @@ +"""Web fetch tool — safely retrieve public web page content.""" + +import logging +from typing import Any + +import aiohttp +import html2text + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ToolResponseBase, + 
WebFetchResponse, +) +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + +# Limits +_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap +_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15) + +# Content types we'll read as text +_TEXT_CONTENT_TYPES = { + "text/html", + "text/plain", + "text/xml", + "text/csv", + "text/markdown", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/rss+xml", + "application/atom+xml", +} + + +def _is_text_content(content_type: str) -> bool: + base = content_type.split(";")[0].strip().lower() + return base in _TEXT_CONTENT_TYPES or base.startswith("text/") + + +def _html_to_text(html: str) -> str: + h = html2text.HTML2Text() + h.ignore_links = False + h.ignore_images = True + h.body_width = 0 + return h.handle(html) + + +class WebFetchTool(BaseTool): + """Safely fetch content from a public URL using SSRF-protected HTTP.""" + + @property + def name(self) -> str: + return "web_fetch" + + @property + def description(self) -> str: + return ( + "Fetch the content of a public web page by URL. " + "Returns readable text extracted from HTML by default. " + "Useful for reading documentation, articles, and API responses. " + "Only supports HTTP/HTTPS GET requests to public URLs " + "(private/internal network addresses are blocked)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The public HTTP/HTTPS URL to fetch.", + }, + "extract_text": { + "type": "boolean", + "description": ( + "If true (default), extract readable text from HTML. " + "If false, return raw content." + ), + "default": True, + }, + }, + "required": ["url"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + url: str = (kwargs.get("url") or "").strip() + extract_text: bool = kwargs.get("extract_text", True) + session_id = session.session_id if session else None + + if not url: + return ErrorResponse( + message="Please provide a URL to fetch.", + error="missing_url", + session_id=session_id, + ) + + try: + client = Requests(raise_for_status=False, retry_max_attempts=1) + response = await client.get(url, timeout=_REQUEST_TIMEOUT) + except ValueError as e: + # validate_url raises ValueError for SSRF / blocked IPs + return ErrorResponse( + message=f"URL blocked: {e}", + error="url_blocked", + session_id=session_id, + ) + except Exception as e: + logger.warning(f"[web_fetch] Request failed for {url}: {e}") + return ErrorResponse( + message=f"Failed to fetch URL: {e}", + error="fetch_failed", + session_id=session_id, + ) + + content_type = response.headers.get("content-type", "") + if not _is_text_content(content_type): + return ErrorResponse( + message=f"Non-text content type: {content_type.split(';')[0]}", + error="unsupported_content_type", + session_id=session_id, + ) + + raw = response.content[:_MAX_CONTENT_BYTES] + text = raw.decode("utf-8", errors="replace") + + if extract_text and "html" in content_type.lower(): + text = _html_to_text(text) + + return WebFetchResponse( + message=f"Fetched {url}", + url=response.url, + status_code=response.status, + content_type=content_type.split(";")[0].strip(), + content=text, + truncated=False, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py 
b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py index 03532c8fee..f37d2c80e0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py @@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool): @property def description(self) -> str: return ( - "List files in the user's workspace. " + "List files in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read/Glob tools instead. " "Returns file names, paths, sizes, and metadata. " "Optionally filter by path prefix." ) @@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Read a file from the user's workspace. " + "Read a file from the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read tool instead. " "Specify either file_id or path to identify the file. " "For small text files, returns content directly. " "For large or binary files, returns metadata and a download URL. " @@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Write or create a file in the user's workspace. " + "Write or create a file in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Write tool instead. " "Provide the content as a base64-encoded string. " f"Maximum file size is {Config().max_file_size_mb}MB. " "Files are saved to the current session's folder by default. " @@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Delete a file from the user's workspace. " + "Delete a file from the user's persistent workspace (cloud storage). " "Specify either file_id or path to identify the file. " "Paths are scoped to the current session by default. " "Use /sessions//... for cross-session access." 
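The ephemeral session files mentioned above live in the same per-session directory that the sandbox helpers resolve. A minimal usage sketch, assuming only the sandbox module added earlier in this diff (the run_user_command wrapper is illustrative and not part of the PR):

from backend.api.features.chat.tools.sandbox import (
    get_workspace_dir,
    has_full_sandbox,
    run_sandboxed,
)


async def run_user_command(session_id: str, argv: list[str]) -> str:
    # Callers must refuse to run unsandboxed if bubblewrap is unavailable.
    if not has_full_sandbox():
        raise RuntimeError("bubblewrap (bwrap) is required for sandboxed execution")

    # Same /tmp/copilot-<id> directory the SDK file tools use for this session.
    workspace = get_workspace_dir(session_id)

    stdout, stderr, exit_code, timed_out = await run_sandboxed(
        argv, cwd=workspace, timeout=30
    )
    if timed_out:
        return stderr  # e.g. "Execution timed out after 30s"
    return stdout if exit_code == 0 else stderr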
diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index 00500dc8a8..4eacf83e71 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -1,7 +1,7 @@ import asyncio import logging from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Annotated, List, Literal +from typing import TYPE_CHECKING, Annotated, Any, List, Literal from autogpt_libs.auth import get_user_id from fastapi import ( @@ -14,7 +14,7 @@ from fastapi import ( Security, status, ) -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, model_validator from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from backend.api.features.library.db import set_preset_webhook, update_preset @@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform -from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.credentials_store import provider_matches +from backend.integrations.creds_manager import ( + IntegrationCredentialsManager, + create_mcp_oauth_handler, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.integrations.webhooks import get_webhook_manager @@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel): scopes: list[str] | None username: str | None host: str | None = Field( - default=None, description="Host pattern for host-scoped credentials" + default=None, + description="Host pattern for host-scoped or MCP server URL for MCP credentials", ) + @model_validator(mode="before") + @classmethod + def _normalize_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.""" + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + + @staticmethod + def get_host(cred: Credentials) -> str | None: + """Extract host from credential: HostScoped host or MCP server URL.""" + if isinstance(cred, HostScopedCredentials): + return cred.host + if isinstance(cred, OAuth2Credentials) and cred.provider in ( + ProviderName.MCP, + ProviderName.MCP.value, + "ProviderName.MCP", + ): + return (cred.metadata or {}).get("mcp_server_url") + return None + @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") async def callback( @@ -179,9 +211,7 @@ async def callback( title=credentials.title, scopes=credentials.scopes, username=credentials.username, - host=( - credentials.host if isinstance(credentials, HostScopedCredentials) else None - ), + host=(CredentialsMetaResponse.get_host(credentials)), ) @@ -199,7 +229,7 @@ async def list_credentials( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + 
host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -222,7 +252,7 @@ async def list_credentials_by_provider( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -322,7 +352,11 @@ async def delete_credentials( tokens_revoked = None if isinstance(creds, OAuth2Credentials): - handler = _get_provider_oauth_handler(request, provider) + if provider_matches(provider.value, ProviderName.MCP.value): + # MCP uses dynamic per-server OAuth — create handler from metadata + handler = create_mcp_oauth_handler(creds) + else: + handler = _get_provider_oauth_handler(request, provider) tokens_revoked = await handler.revoke_tokens(creds) return CredentialsDeletionResponse(revoked=tokens_revoked) diff --git a/autogpt_platform/backend/backend/api/features/mcp/__init__.py b/autogpt_platform/backend/backend/api/features/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/api/features/mcp/routes.py b/autogpt_platform/backend/backend/api/features/mcp/routes.py new file mode 100644 index 0000000000..f8d311f372 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/routes.py @@ -0,0 +1,404 @@ +""" +MCP (Model Context Protocol) API routes. + +Provides endpoints for MCP tool discovery and OAuth authentication so the +frontend can list available tools on an MCP server before placing a block. +""" + +import logging +from typing import Annotated, Any +from urllib.parse import urlparse + +import fastapi +from autogpt_libs.auth import get_user_id +from fastapi import Security +from pydantic import BaseModel, Field + +from backend.api.features.integrations.router import CredentialsMetaResponse +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials +from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName +from backend.util.request import HTTPClientError, Requests +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +settings = Settings() +router = fastapi.APIRouter(tags=["mcp"]) +creds_manager = IntegrationCredentialsManager() + + +# ====================== Tool Discovery ====================== # + + +class DiscoverToolsRequest(BaseModel): + """Request to discover tools on an MCP server.""" + + server_url: str = Field(description="URL of the MCP server") + auth_token: str | None = Field( + default=None, + description="Optional Bearer token for authenticated MCP servers", + ) + + +class MCPToolResponse(BaseModel): + """A single MCP tool returned by discovery.""" + + name: str + description: str + input_schema: dict[str, Any] + + +class DiscoverToolsResponse(BaseModel): + """Response containing the list of tools available on an MCP server.""" + + tools: list[MCPToolResponse] + server_name: str | None = None + protocol_version: str | None = None + + +@router.post( + "/discover-tools", + summary="Discover available tools on an MCP server", + response_model=DiscoverToolsResponse, +) +async def discover_tools( + request: DiscoverToolsRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> DiscoverToolsResponse: + """ + Connect to an MCP server and return 
its available tools. + + If the user has a stored MCP credential for this server URL, it will be + used automatically — no need to pass an explicit auth token. + """ + auth_token = request.auth_token + + # Auto-use stored MCP credential when no explicit token is provided. + if not auth_token: + mcp_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + # Find the freshest credential for this server URL + best_cred: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == request.server_url + ): + if best_cred is None or ( + (cred.access_token_expires_at or 0) + > (best_cred.access_token_expires_at or 0) + ): + best_cred = cred + if best_cred: + # Refresh the token if expired before using it + best_cred = await creds_manager.refresh_if_needed(user_id, best_cred) + logger.info( + f"Using MCP credential {best_cred.id} for {request.server_url}, " + f"expires_at={best_cred.access_token_expires_at}" + ) + auth_token = best_cred.access_token.get_secret_value() + + client = MCPClient(request.server_url, auth_token=auth_token) + + try: + init_result = await client.initialize() + tools = await client.list_tools() + except HTTPClientError as e: + if e.status_code in (401, 403): + raise fastapi.HTTPException( + status_code=401, + detail="This MCP server requires authentication. " + "Please provide a valid auth token.", + ) + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except MCPClientError as e: + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except Exception as e: + raise fastapi.HTTPException( + status_code=502, + detail=f"Failed to connect to MCP server: {e}", + ) + + return DiscoverToolsResponse( + tools=[ + MCPToolResponse( + name=t.name, + description=t.description, + input_schema=t.input_schema, + ) + for t in tools + ], + server_name=( + init_result.get("serverInfo", {}).get("name") + or urlparse(request.server_url).hostname + or "MCP" + ), + protocol_version=init_result.get("protocolVersion"), + ) + + +# ======================== OAuth Flow ======================== # + + +class MCPOAuthLoginRequest(BaseModel): + """Request to start an OAuth flow for an MCP server.""" + + server_url: str = Field(description="URL of the MCP server that requires OAuth") + + +class MCPOAuthLoginResponse(BaseModel): + """Response with the OAuth login URL for the user to authenticate.""" + + login_url: str + state_token: str + + +@router.post( + "/oauth/login", + summary="Initiate OAuth login for an MCP server", +) +async def mcp_oauth_login( + request: MCPOAuthLoginRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> MCPOAuthLoginResponse: + """ + Discover OAuth metadata from the MCP server and return a login URL. + + 1. Discovers the protected-resource metadata (RFC 9728) + 2. Fetches the authorization server metadata (RFC 8414) + 3. Performs Dynamic Client Registration (RFC 7591) if available + 4. 
Returns the authorization URL for the frontend to open in a popup + """ + client = MCPClient(request.server_url) + + # Step 1: Discover protected-resource metadata (RFC 9728) + protected_resource = await client.discover_auth() + + metadata: dict[str, Any] | None = None + + if protected_resource and protected_resource.get("authorization_servers"): + auth_server_url = protected_resource["authorization_servers"][0] + resource_url = protected_resource.get("resource", request.server_url) + + # Step 2a: Discover auth-server metadata (RFC 8414) + metadata = await client.discover_auth_server_metadata(auth_server_url) + else: + # Fallback: Some MCP servers (e.g. Linear) are their own auth server + # and serve OAuth metadata directly without protected-resource metadata. + # Don't assume a resource_url — omitting it lets the auth server choose + # the correct audience for the token (RFC 8707 resource is optional). + resource_url = None + metadata = await client.discover_auth_server_metadata(request.server_url) + + if ( + not metadata + or "authorization_endpoint" not in metadata + or "token_endpoint" not in metadata + ): + raise fastapi.HTTPException( + status_code=400, + detail="This MCP server does not advertise OAuth support. " + "You may need to provide an auth token manually.", + ) + + authorize_url = metadata["authorization_endpoint"] + token_url = metadata["token_endpoint"] + registration_endpoint = metadata.get("registration_endpoint") + revoke_url = metadata.get("revocation_endpoint") + + # Step 3: Dynamic Client Registration (RFC 7591) if available + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + client_id = "" + client_secret = "" + if registration_endpoint: + reg_result = await _register_mcp_client( + registration_endpoint, redirect_uri, request.server_url + ) + if reg_result: + client_id = reg_result.get("client_id", "") + client_secret = reg_result.get("client_secret", "") + + if not client_id: + client_id = "autogpt-platform" + + # Step 4: Store state token with OAuth metadata for the callback + scopes = (protected_resource or {}).get("scopes_supported") or metadata.get( + "scopes_supported", [] + ) + state_token, code_challenge = await creds_manager.store.store_state_token( + user_id, + ProviderName.MCP.value, + scopes, + state_metadata={ + "authorize_url": authorize_url, + "token_url": token_url, + "revoke_url": revoke_url, + "resource_url": resource_url, + "server_url": request.server_url, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Step 5: Build and return the login URL + handler = MCPOAuthHandler( + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + authorize_url=authorize_url, + token_url=token_url, + resource_url=resource_url, + ) + login_url = handler.get_login_url( + scopes, state_token, code_challenge=code_challenge + ) + + return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token) + + +class MCPOAuthCallbackRequest(BaseModel): + """Request to exchange an OAuth code for tokens.""" + + code: str = Field(description="Authorization code from OAuth callback") + state_token: str = Field(description="State token for CSRF verification") + + +class MCPOAuthCallbackResponse(BaseModel): + """Response after successfully storing OAuth credentials.""" + + credential_id: str + + +@router.post( + 
"/oauth/callback", + summary="Exchange OAuth code for MCP tokens", +) +async def mcp_oauth_callback( + request: MCPOAuthCallbackRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> CredentialsMetaResponse: + """ + Exchange the authorization code for tokens and store the credential. + + The frontend calls this after receiving the OAuth code from the popup. + On success, subsequent ``/discover-tools`` calls for the same server URL + will automatically use the stored credential. + """ + valid_state = await creds_manager.store.verify_state_token( + user_id, request.state_token, ProviderName.MCP.value + ) + if not valid_state: + raise fastapi.HTTPException( + status_code=400, + detail="Invalid or expired state token.", + ) + + meta = valid_state.state_metadata + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + handler = MCPOAuthHandler( + client_id=meta["client_id"], + client_secret=meta.get("client_secret", ""), + redirect_uri=redirect_uri, + authorize_url=meta["authorize_url"], + token_url=meta["token_url"], + revoke_url=meta.get("revoke_url"), + resource_url=meta.get("resource_url"), + ) + + try: + credentials = await handler.exchange_code_for_tokens( + request.code, valid_state.scopes, valid_state.code_verifier + ) + except Exception as e: + raise fastapi.HTTPException( + status_code=400, + detail=f"OAuth token exchange failed: {e}", + ) + + # Enrich credential metadata for future lookup and token refresh + if credentials.metadata is None: + credentials.metadata = {} + credentials.metadata["mcp_server_url"] = meta["server_url"] + credentials.metadata["mcp_client_id"] = meta["client_id"] + credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "") + credentials.metadata["mcp_token_url"] = meta["token_url"] + credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "") + + hostname = urlparse(meta["server_url"]).hostname or meta["server_url"] + credentials.title = f"MCP: {hostname}" + + # Remove old MCP credentials for the same server to prevent stale token buildup. 
+ try: + old_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + for old in old_creds: + if ( + isinstance(old, OAuth2Credentials) + and (old.metadata or {}).get("mcp_server_url") == meta["server_url"] + ): + await creds_manager.store.delete_creds_by_id(user_id, old.id) + logger.info( + f"Removed old MCP credential {old.id} for {meta['server_url']}" + ) + except Exception: + logger.debug("Could not clean up old MCP credentials", exc_info=True) + + await creds_manager.create(user_id, credentials) + + return CredentialsMetaResponse( + id=credentials.id, + provider=credentials.provider, + type=credentials.type, + title=credentials.title, + scopes=credentials.scopes, + username=credentials.username, + host=credentials.metadata.get("mcp_server_url"), + ) + + +# ======================== Helpers ======================== # + + +async def _register_mcp_client( + registration_endpoint: str, + redirect_uri: str, + server_url: str, +) -> dict[str, Any] | None: + """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server.""" + try: + response = await Requests(raise_for_status=True).post( + registration_endpoint, + json={ + "client_name": "AutoGPT Platform", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + }, + ) + data = response.json() + if isinstance(data, dict) and "client_id" in data: + return data + return None + except Exception as e: + logger.warning(f"Dynamic client registration failed for {server_url}: {e}") + return None diff --git a/autogpt_platform/backend/backend/api/features/mcp/test_routes.py b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py new file mode 100644 index 0000000000..e86b9f4865 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py @@ -0,0 +1,436 @@ +"""Tests for MCP API routes. + +Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient +to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop. 
+""" + +from unittest.mock import AsyncMock, patch + +import fastapi +import httpx +import pytest +import pytest_asyncio +from autogpt_libs.auth import get_user_id + +from backend.api.features.mcp.routes import router +from backend.blocks.mcp.client import MCPClientError, MCPTool +from backend.util.request import HTTPClientError + +app = fastapi.FastAPI() +app.include_router(router) +app.dependency_overrides[get_user_id] = lambda: "test-user-id" + + +@pytest_asyncio.fixture(scope="module") +async def client(): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestDiscoverTools: + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_success(self, client): + mock_tools = [ + MCPTool( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + MCPTool( + name="add_numbers", + description="Add two numbers", + input_schema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + }, + ), + ] + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={ + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test-server"}, + } + ) + instance.list_tools = AsyncMock(return_value=mock_tools) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["tools"]) == 2 + assert data["tools"][0]["name"] == "get_weather" + assert data["tools"][1]["name"] == "add_numbers" + assert data["server_name"] == "test-server" + assert data["protocol_version"] == "2025-03-26" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_with_auth_token(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={ + "server_url": "https://mcp.example.com/mcp", + "auth_token": "my-secret-token", + }, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="my-secret-token", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auto_uses_stored_credential(self, client): + """When no explicit token is given, stored MCP credentials are used.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + stored_cred = OAuth2Credentials( + provider="mcp", + title="MCP: example.com", + access_token=SecretStr("stored-token-123"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={"mcp_server_url": "https://mcp.example.com/mcp"}, + ) + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred]) + 
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="stored-token-123", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_mcp_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=MCPClientError("Connection refused") + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://bad-server.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Connection refused" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_generic_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock(side_effect=Exception("Network timeout")) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://timeout.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Failed to connect" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auth_required(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_forbidden(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_missing_url(self, client): + response = await client.post("/discover-tools", json={}) + assert response.status_code == 422 + + +class TestOAuthLogin: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_success(self, client): + with ( + 
patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch( + "backend.api.features.mcp.routes._register_mcp_client" + ) as mock_register, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.sentry.io"], + "resource": "https://mcp.sentry.dev/mcp", + "scopes_supported": ["openid"], + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.sentry.io/authorize", + "token_endpoint": "https://auth.sentry.io/token", + "registration_endpoint": "https://auth.sentry.io/register", + } + ) + mock_register.return_value = { + "client_id": "registered-client-id", + "client_secret": "registered-secret", + } + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-token-123", "code-challenge-abc") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.sentry.dev/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "login_url" in data + assert data["state_token"] == "state-token-123" + assert "auth.sentry.io/authorize" in data["login_url"] + assert "registered-client-id" in data["login_url"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_no_oauth_support(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.discover_auth = AsyncMock(return_value=None) + instance.discover_auth_server_metadata = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/login", + json={"server_url": "https://simple-server.example.com/mcp"}, + ) + + assert response.status_code == 400 + assert "does not advertise OAuth" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_fallback_to_public_client(self, client): + """When DCR is unavailable, falls back to default public client ID.""" + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + # No registration_endpoint + } + ) + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-abc", "challenge-xyz") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "autogpt-platform" in data["login_url"] + + +class TestOAuthCallback: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_success(self, client): + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + mock_creds = OAuth2Credentials( + provider="mcp", + title=None, + 
access_token=SecretStr("access-token-xyz"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={ + "mcp_token_url": "https://auth.sentry.io/token", + "mcp_resource_url": "https://mcp.sentry.dev/mcp", + }, + ) + + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + + # Mock state verification + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.sentry.io/authorize", + "token_url": "https://auth.sentry.io/token", + "client_id": "test-client-id", + "client_secret": "test-secret", + "server_url": "https://mcp.sentry.dev/mcp", + } + mock_state.scopes = ["openid"] + mock_state.code_verifier = "verifier-123" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + mock_cm.create = AsyncMock() + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + return_value=mock_creds + ) + + # Mock old credential cleanup + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code-abc", "state_token": "state-token-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["provider"] == "mcp" + assert data["type"] == "oauth2" + mock_cm.create.assert_called_once() + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_invalid_state(self, client): + with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm: + mock_cm.store.verify_state_token = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code", "state_token": "bad-state"}, + ) + + assert response.status_code == 400 + assert "Invalid or expired" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_token_exchange_fails(self, client): + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "server_url": "https://mcp.example.com/mcp", + } + mock_state.scopes = [] + mock_state.code_verifier = "v" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + side_effect=RuntimeError("Token exchange failed") + ) + + response = await client.post( + "/oauth/callback", + json={"code": "bad-code", "state_token": "state"}, + ) + + assert response.status_code == 400 + assert "token exchange failed" in response.json()["detail"].lower() diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index da87d53391..5da18a15c3 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -27,6 +27,7 @@ import 
backend.api.features.executions.review.routes import backend.api.features.library.db import backend.api.features.library.model import backend.api.features.library.routes +import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes import backend.api.features.postmark.postmark @@ -372,6 +373,11 @@ app.include_router( tags=["workspace"], prefix="/api/workspace", ) +app.include_router( + mcp_routes.router, + tags=["v2", "mcp"], + prefix="/api/mcp", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 2d6fd7a764..01f690a122 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -64,6 +64,7 @@ class BlockType(Enum): AI = "AI" AYRSHARE = "Ayrshare" HUMAN_IN_THE_LOOP = "Human In The Loop" + MCP_TOOL = "MCP Tool" class BlockCategory(Enum): diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index f129d2707b..5fdcfb6d82 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -126,6 +126,7 @@ class PrintToConsoleBlock(Block): output_schema=PrintToConsoleBlock.Output, test_input={"text": "Hello, World!"}, is_sensitive_action=True, + disabled=True, # Disabled per Nick Tindle's request (OPEN-3000) test_output=[ ("output", "Hello, World!"), ("status", "printed"), diff --git a/autogpt_platform/backend/backend/blocks/data_manipulation.py b/autogpt_platform/backend/backend/blocks/data_manipulation.py index a8f25ecb18..fe878acfa9 100644 --- a/autogpt_platform/backend/backend/blocks/data_manipulation.py +++ b/autogpt_platform/backend/backend/blocks/data_manipulation.py @@ -682,17 +682,219 @@ class ListIsEmptyBlock(Block): yield "is_empty", len(input_data.list) == 0 +# ============================================================================= +# List Concatenation Helpers +# ============================================================================= + + +def _validate_list_input(item: Any, index: int) -> str | None: + """Validate that an item is a list. Returns error message or None.""" + if item is None: + return None # None is acceptable, will be skipped + if not isinstance(item, list): + return ( + f"Invalid input at index {index}: expected a list, " + f"got {type(item).__name__}. " + f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." + ) + return None + + +def _validate_all_lists(lists: List[Any]) -> str | None: + """Validate that all items in a sequence are lists. Returns first error or None.""" + for idx, item in enumerate(lists): + error = _validate_list_input(item, idx) + if error is not None and item is not None: + return error + return None + + +def _concatenate_lists_simple(lists: List[List[Any]]) -> List[Any]: + """Concatenate a sequence of lists into a single list, skipping None values.""" + result: List[Any] = [] + for lst in lists: + if lst is None: + continue + result.extend(lst) + return result + + +def _flatten_nested_list(nested: List[Any], max_depth: int = -1) -> List[Any]: + """ + Recursively flatten a nested list structure. + + Args: + nested: The list to flatten. + max_depth: Maximum recursion depth. -1 means unlimited. + + Returns: + A flat list with all nested elements extracted. 
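+ + Example (illustrative, derived from the logic below): + _flatten_nested_list([[1, [2, 3]], 4]) -> [1, 2, 3, 4] + _flatten_nested_list([[1, [2, 3]], 4], max_depth=1) -> [1, [2, 3], 4]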
+ """ + result: List[Any] = [] + _flatten_recursive(nested, result, current_depth=0, max_depth=max_depth) + return result + + +_MAX_FLATTEN_DEPTH = 1000 + + +def _flatten_recursive( + items: List[Any], + result: List[Any], + current_depth: int, + max_depth: int, +) -> None: + """Internal recursive helper for flattening nested lists.""" + if current_depth > _MAX_FLATTEN_DEPTH: + raise RecursionError( + f"Flattening exceeded maximum depth of {_MAX_FLATTEN_DEPTH} levels. " + "Input may be too deeply nested." + ) + for item in items: + if isinstance(item, list) and (max_depth == -1 or current_depth < max_depth): + _flatten_recursive(item, result, current_depth + 1, max_depth) + else: + result.append(item) + + +def _deduplicate_list(items: List[Any]) -> List[Any]: + """ + Remove duplicate elements from a list, preserving order of first occurrences. + + Args: + items: The list to deduplicate. + + Returns: + A list with duplicates removed, maintaining original order. + """ + seen: set = set() + result: List[Any] = [] + for item in items: + item_id = _make_hashable(item) + if item_id not in seen: + seen.add(item_id) + result.append(item) + return result + + +def _make_hashable(item: Any): + """ + Create a hashable representation of any item for deduplication. + Converts unhashable types (dicts, lists) into deterministic tuple structures. + """ + if isinstance(item, dict): + return tuple( + sorted( + ((_make_hashable(k), _make_hashable(v)) for k, v in item.items()), + key=lambda x: (str(type(x[0])), str(x[0])), + ) + ) + if isinstance(item, (list, tuple)): + return tuple(_make_hashable(i) for i in item) + if isinstance(item, set): + return frozenset(_make_hashable(i) for i in item) + return item + + +def _filter_none_values(items: List[Any]) -> List[Any]: + """Remove None values from a list.""" + return [item for item in items if item is not None] + + +def _compute_nesting_depth( + items: Any, current: int = 0, max_depth: int = _MAX_FLATTEN_DEPTH +) -> int: + """ + Compute the maximum nesting depth of a list structure using iteration to avoid RecursionError. + + Uses a stack-based approach to handle deeply nested structures without hitting Python's + recursion limit (~1000 levels). + """ + if not isinstance(items, list): + return current + + # Stack contains tuples of (item, depth) + stack = [(items, current)] + max_observed_depth = current + + while stack: + item, depth = stack.pop() + + if depth > max_depth: + return depth + + if not isinstance(item, list): + max_observed_depth = max(max_observed_depth, depth) + continue + + if len(item) == 0: + max_observed_depth = max(max_observed_depth, depth + 1) + continue + + # Add all children to stack with incremented depth + for child in item: + stack.append((child, depth + 1)) + + return max_observed_depth + + +def _interleave_lists(lists: List[List[Any]]) -> List[Any]: + """ + Interleave elements from multiple lists in round-robin fashion. 
+ Example: [[1,2,3], [a,b], [x,y,z]] -> [1, a, x, 2, b, y, 3, z] + """ + if not lists: + return [] + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + result: List[Any] = [] + max_len = max(len(lst) for lst in filtered) + for i in range(max_len): + for lst in filtered: + if i < len(lst): + result.append(lst[i]) + return result + + +# ============================================================================= +# List Concatenation Blocks +# ============================================================================= + + class ConcatenateListsBlock(Block): + """ + Concatenates two or more lists into a single list. + + This block accepts a list of lists and combines all their elements + in order into one flat output list. It supports options for + deduplication and None-filtering to provide flexible list merging + capabilities for workflow pipelines. + """ + class Input(BlockSchemaInput): lists: List[List[Any]] = SchemaField( description="A list of lists to concatenate together. All lists will be combined in order into a single list.", placeholder="e.g., [[1, 2], [3, 4], [5, 6]]", ) + deduplicate: bool = SchemaField( + description="If True, remove duplicate elements from the concatenated result while preserving order.", + default=False, + advanced=True, + ) + remove_none: bool = SchemaField( + description="If True, remove None values from the concatenated result.", + default=False, + advanced=True, + ) class Output(BlockSchemaOutput): concatenated_list: List[Any] = SchemaField( description="The concatenated list containing all elements from all input lists in order." ) + length: int = SchemaField( + description="The total number of elements in the concatenated list." + ) error: str = SchemaField( description="Error message if concatenation failed due to invalid input types." ) @@ -700,7 +902,7 @@ class ConcatenateListsBlock(Block): def __init__(self): super().__init__( id="3cf9298b-5817-4141-9d80-7c2cc5199c8e", - description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order.", + description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order. 
Supports optional deduplication and None removal.", categories={BlockCategory.BASIC}, input_schema=ConcatenateListsBlock.Input, output_schema=ConcatenateListsBlock.Output, @@ -709,29 +911,497 @@ class ConcatenateListsBlock(Block): {"lists": [["a", "b"], ["c"], ["d", "e", "f"]]}, {"lists": [[1, 2], []]}, {"lists": []}, + {"lists": [[1, 2, 2, 3], [3, 4]], "deduplicate": True}, + {"lists": [[1, None, 2], [None, 3]], "remove_none": True}, ], test_output=[ ("concatenated_list", [1, 2, 3, 4, 5, 6]), + ("length", 6), ("concatenated_list", ["a", "b", "c", "d", "e", "f"]), + ("length", 6), ("concatenated_list", [1, 2]), + ("length", 2), ("concatenated_list", []), + ("length", 0), + ("concatenated_list", [1, 2, 3, 4]), + ("length", 4), + ("concatenated_list", [1, 2, 3]), + ("length", 3), ], ) + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _perform_concatenation(self, lists: List[List[Any]]) -> List[Any]: + return _concatenate_lists_simple(lists) + + def _apply_deduplication(self, items: List[Any]) -> List[Any]: + return _deduplicate_list(items) + + def _apply_none_removal(self, items: List[Any]) -> List[Any]: + return _filter_none_values(items) + + def _post_process( + self, items: List[Any], deduplicate: bool, remove_none: bool + ) -> List[Any]: + """Apply all post-processing steps to the concatenated result.""" + result = items + if remove_none: + result = self._apply_none_removal(result) + if deduplicate: + result = self._apply_deduplication(result) + return result + async def run(self, input_data: Input, **kwargs) -> BlockOutput: - concatenated = [] - for idx, lst in enumerate(input_data.lists): - if lst is None: - # Skip None values to avoid errors - continue - if not isinstance(lst, list): - # Type validation: each item must be a list - # Strings are iterable and would cause extend() to iterate character-by-character - # Non-iterable types would raise TypeError - yield "error", ( - f"Invalid input at index {idx}: expected a list, got {type(lst).__name__}. " - f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." - ) - return - concatenated.extend(lst) - yield "concatenated_list", concatenated + # Validate all inputs are lists + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + # Perform concatenation + concatenated = self._perform_concatenation(input_data.lists) + + # Apply post-processing + result = self._post_process( + concatenated, input_data.deduplicate, input_data.remove_none + ) + + yield "concatenated_list", result + yield "length", len(result) + + +class FlattenListBlock(Block): + """ + Flattens a nested list structure into a single flat list. + + This block takes a list that may contain nested lists at any depth + and produces a single-level list with all leaf elements. Useful + for normalizing data structures from multiple sources that may + have varying levels of nesting. + """ + + class Input(BlockSchemaInput): + nested_list: List[Any] = SchemaField( + description="A potentially nested list to flatten into a single-level list.", + placeholder="e.g., [[1, [2, 3]], [4, [5, [6]]]]", + ) + max_depth: int = SchemaField( + description="Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level.", + default=-1, + advanced=True, + ) + + class Output(BlockSchemaOutput): + flattened_list: List[Any] = SchemaField( + description="The flattened list with all nested elements extracted." 
+ ) + length: int = SchemaField( + description="The number of elements in the flattened list." + ) + original_depth: int = SchemaField( + description="The maximum nesting depth of the original input list." + ) + error: str = SchemaField(description="Error message if flattening failed.") + + def __init__(self): + super().__init__( + id="cc45bb0f-d035-4756-96a7-fe3e36254b4d", + description="Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth.", + categories={BlockCategory.BASIC}, + input_schema=FlattenListBlock.Input, + output_schema=FlattenListBlock.Output, + test_input=[ + {"nested_list": [[1, 2], [3, [4, 5]]]}, + {"nested_list": [1, [2, [3, [4]]]]}, + {"nested_list": [1, [2, [3, [4]]], 5], "max_depth": 1}, + {"nested_list": []}, + {"nested_list": [1, 2, 3]}, + ], + test_output=[ + ("flattened_list", [1, 2, 3, 4, 5]), + ("length", 5), + ("original_depth", 3), + ("flattened_list", [1, 2, 3, 4]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", [1, 2, [3, [4]], 5]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", []), + ("length", 0), + ("original_depth", 1), + ("flattened_list", [1, 2, 3]), + ("length", 3), + ("original_depth", 1), + ], + ) + + def _compute_depth(self, items: List[Any]) -> int: + """Compute the nesting depth of the input list.""" + return _compute_nesting_depth(items) + + def _flatten(self, items: List[Any], max_depth: int) -> List[Any]: + """Flatten the list to the specified depth.""" + return _flatten_nested_list(items, max_depth=max_depth) + + def _validate_max_depth(self, max_depth: int) -> str | None: + """Validate the max_depth parameter.""" + if max_depth < -1: + return f"max_depth must be -1 (unlimited) or a non-negative integer, got {max_depth}" + return None + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + # Validate max_depth + depth_error = self._validate_max_depth(input_data.max_depth) + if depth_error is not None: + yield "error", depth_error + return + + original_depth = self._compute_depth(input_data.nested_list) + flattened = self._flatten(input_data.nested_list, input_data.max_depth) + + yield "flattened_list", flattened + yield "length", len(flattened) + yield "original_depth", original_depth + + +class InterleaveListsBlock(Block): + """ + Interleaves elements from multiple lists in round-robin fashion. + + Given multiple input lists, this block takes one element from each + list in turn, producing an output where elements alternate between + sources. Lists of different lengths are handled gracefully - shorter + lists simply stop contributing once exhausted. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to interleave. Elements will be taken in round-robin order.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + + class Output(BlockSchemaOutput): + interleaved_list: List[Any] = SchemaField( + description="The interleaved list with elements alternating from each input list." + ) + length: int = SchemaField( + description="The total number of elements in the interleaved list." 
+ ) + error: str = SchemaField(description="Error message if interleaving failed.") + + def __init__(self): + super().__init__( + id="9f616084-1d9f-4f8e-bc00-5b9d2a75cd75", + description="Interleaves elements from multiple lists in round-robin fashion, alternating between sources.", + categories={BlockCategory.BASIC}, + input_schema=InterleaveListsBlock.Input, + output_schema=InterleaveListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"], ["x", "y", "z"]]}, + {"lists": [[1], [2], [3]]}, + {"lists": []}, + ], + test_output=[ + ("interleaved_list", [1, "a", 2, "b", 3, "c"]), + ("length", 6), + ("interleaved_list", [1, "a", "x", 2, "b", "y", 3, "z"]), + ("length", 8), + ("interleaved_list", [1, 2, 3]), + ("length", 3), + ("interleaved_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _interleave(self, lists: List[List[Any]]) -> List[Any]: + return _interleave_lists(lists) + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + result = self._interleave(input_data.lists) + yield "interleaved_list", result + yield "length", len(result) + + +class ZipListsBlock(Block): + """ + Zips multiple lists together into a list of grouped tuples/lists. + + Takes two or more input lists and combines corresponding elements + into sub-lists. For example, zipping [1,2,3] and ['a','b','c'] + produces [[1,'a'], [2,'b'], [3,'c']]. Supports both truncating + to shortest list and padding to longest list with a fill value. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to zip together. Corresponding elements will be grouped.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + pad_to_longest: bool = SchemaField( + description="If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest.", + default=False, + advanced=True, + ) + fill_value: Any = SchemaField( + description="Value to use for padding when pad_to_longest is True.", + default=None, + advanced=True, + ) + + class Output(BlockSchemaOutput): + zipped_list: List[List[Any]] = SchemaField( + description="The zipped list of grouped elements." + ) + length: int = SchemaField( + description="The number of groups in the zipped result." + ) + error: str = SchemaField(description="Error message if zipping failed.") + + def __init__(self): + super().__init__( + id="0d0e684f-5cb9-4c4b-b8d1-47a0860e0c07", + description="Zips multiple lists together into a list of grouped elements. 
Supports padding to longest or truncating to shortest.", + categories={BlockCategory.BASIC}, + input_schema=ZipListsBlock.Input, + output_schema=ZipListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"]]}, + { + "lists": [[1, 2], ["a", "b", "c"]], + "pad_to_longest": True, + "fill_value": 0, + }, + {"lists": []}, + ], + test_output=[ + ("zipped_list", [[1, "a"], [2, "b"], [3, "c"]]), + ("length", 3), + ("zipped_list", [[1, "a"], [2, "b"]]), + ("length", 2), + ("zipped_list", [[1, "a"], [2, "b"], [0, "c"]]), + ("length", 3), + ("zipped_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _zip_truncate(self, lists: List[List[Any]]) -> List[List[Any]]: + """Zip lists, truncating to shortest.""" + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + return [list(group) for group in zip(*filtered)] + + def _zip_pad(self, lists: List[List[Any]], fill_value: Any) -> List[List[Any]]: + """Zip lists, padding shorter ones with fill_value.""" + if not lists: + return [] + lists = [lst for lst in lists if lst is not None] + if not lists: + return [] + max_len = max(len(lst) for lst in lists) + result: List[List[Any]] = [] + for i in range(max_len): + group: List[Any] = [] + for lst in lists: + if i < len(lst): + group.append(lst[i]) + else: + group.append(fill_value) + result.append(group) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + if not input_data.lists: + yield "zipped_list", [] + yield "length", 0 + return + + if input_data.pad_to_longest: + result = self._zip_pad(input_data.lists, input_data.fill_value) + else: + result = self._zip_truncate(input_data.lists) + + yield "zipped_list", result + yield "length", len(result) + + +class ListDifferenceBlock(Block): + """ + Computes the difference between two lists (elements in the first + list that are not in the second list). + + This is useful for finding items that exist in one dataset but + not in another, such as finding new items, missing items, or + items that need to be processed. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The primary list to check elements from.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The list to subtract. Elements found here will be removed from list_a.", + placeholder="e.g., [3, 4, 5, 6]", + ) + symmetric: bool = SchemaField( + description="If True, compute symmetric difference (elements in either list but not both).", + default=False, + advanced=True, + ) + + class Output(BlockSchemaOutput): + difference: List[Any] = SchemaField( + description="Elements from list_a not found in list_b (or symmetric difference if enabled)." + ) + length: int = SchemaField( + description="The number of elements in the difference result." + ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="05309873-9d61-447e-96b5-b804e2511829", + description="Computes the difference between two lists. 
Returns elements in the first list not found in the second, or symmetric difference.", + categories={BlockCategory.BASIC}, + input_schema=ListDifferenceBlock.Input, + output_schema=ListDifferenceBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + { + "list_a": [1, 2, 3, 4, 5], + "list_b": [3, 4, 5, 6, 7], + "symmetric": True, + }, + {"list_a": ["a", "b", "c"], "list_b": ["b"]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("difference", [1, 2]), + ("length", 2), + ("difference", [1, 2, 6, 7]), + ("length", 4), + ("difference", ["a", "c"]), + ("length", 2), + ("difference", []), + ("length", 0), + ], + ) + + def _compute_difference(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements in list_a not in list_b.""" + b_hashes = {_make_hashable(item) for item in list_b} + return [item for item in list_a if _make_hashable(item) not in b_hashes] + + def _compute_symmetric_difference( + self, list_a: List[Any], list_b: List[Any] + ) -> List[Any]: + """Compute elements in either list but not both.""" + a_hashes = {_make_hashable(item) for item in list_a} + b_hashes = {_make_hashable(item) for item in list_b} + only_in_a = [item for item in list_a if _make_hashable(item) not in b_hashes] + only_in_b = [item for item in list_b if _make_hashable(item) not in a_hashes] + return only_in_a + only_in_b + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + if input_data.symmetric: + result = self._compute_symmetric_difference( + input_data.list_a, input_data.list_b + ) + else: + result = self._compute_difference(input_data.list_a, input_data.list_b) + + yield "difference", result + yield "length", len(result) + + +class ListIntersectionBlock(Block): + """ + Computes the intersection of two lists (elements present in both lists). + + This is useful for finding common items between two datasets, + such as shared tags, mutual connections, or overlapping categories. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The first list to intersect.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The second list to intersect.", + placeholder="e.g., [3, 4, 5, 6, 7]", + ) + + class Output(BlockSchemaOutput): + intersection: List[Any] = SchemaField( + description="Elements present in both list_a and list_b." + ) + length: int = SchemaField( + description="The number of elements in the intersection." 
+ ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="b6eb08b6-dbe3-411b-b9b4-2508cb311a1f", + description="Computes the intersection of two lists, returning only elements present in both.", + categories={BlockCategory.BASIC}, + input_schema=ListIntersectionBlock.Input, + output_schema=ListIntersectionBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + {"list_a": ["a", "b", "c"], "list_b": ["c", "d", "e"]}, + {"list_a": [1, 2], "list_b": [3, 4]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("intersection", [3, 4, 5]), + ("length", 3), + ("intersection", ["c"]), + ("length", 1), + ("intersection", []), + ("length", 0), + ("intersection", []), + ("length", 0), + ], + ) + + def _compute_intersection(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements present in both lists, preserving order from list_a.""" + b_hashes = {_make_hashable(item) for item in list_b} + seen: set = set() + result: List[Any] = [] + for item in list_a: + h = _make_hashable(item) + if h in b_hashes and h not in seen: + result.append(item) + seen.add(h) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + result = self._compute_intersection(input_data.list_a, input_data.list_b) + yield "intersection", result + yield "length", len(result) diff --git a/autogpt_platform/backend/backend/blocks/jina/search.py b/autogpt_platform/backend/backend/blocks/jina/search.py index 22a883fa03..5e58ddcab4 100644 --- a/autogpt_platform/backend/backend/blocks/jina/search.py +++ b/autogpt_platform/backend/backend/blocks/jina/search.py @@ -17,6 +17,7 @@ from backend.blocks.jina._auth import ( from backend.blocks.search import GetRequest from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError +from backend.util.request import HTTPClientError, HTTPServerError, validate_url class SearchTheWebBlock(Block, GetRequest): @@ -110,7 +111,12 @@ class ExtractWebsiteContentBlock(Block, GetRequest): self, input_data: Input, *, credentials: JinaCredentials, **kwargs ) -> BlockOutput: if input_data.raw_content: - url = input_data.url + try: + parsed_url, _, _ = await validate_url(input_data.url, []) + url = parsed_url.geturl() + except ValueError as e: + yield "error", f"Invalid URL: {e}" + return headers = {} else: url = f"https://r.jina.ai/{input_data.url}" @@ -119,5 +125,20 @@ class ExtractWebsiteContentBlock(Block, GetRequest): "Authorization": f"Bearer {credentials.api_key.get_secret_value()}", } - content = await self.get_request(url, json=False, headers=headers) + try: + content = await self.get_request(url, json=False, headers=headers) + except HTTPClientError as e: + yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}" + return + except HTTPServerError as e: + yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}" + return + except Exception as e: + yield "error", f"Failed to fetch {input_data.url}: {e}" + return + + if not content: + yield "error", f"No content returned for {input_data.url}" + return + yield "content", content diff --git a/autogpt_platform/backend/backend/blocks/mcp/__init__.py b/autogpt_platform/backend/backend/blocks/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py new file mode 100644 index 
0000000000..9e3056d928 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -0,0 +1,300 @@ +""" +MCP (Model Context Protocol) Tool Block. + +A single dynamic block that can connect to any MCP server, discover available tools, +and execute them. Works like AgentExecutorBlock — the user selects a tool from a +dropdown and the input/output schema adapts dynamically. +""" + +import json +import logging +from typing import Any, Literal + +from pydantic import SecretStr + +from backend.blocks._base import ( + Block, + BlockCategory, + BlockSchemaInput, + BlockSchemaOutput, + BlockType, +) +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.data.block import BlockInput, BlockOutput +from backend.data.model import ( + CredentialsField, + CredentialsMetaInput, + OAuth2Credentials, + SchemaField, +) +from backend.integrations.providers import ProviderName +from backend.util.json import validate_with_jsonschema + +logger = logging.getLogger(__name__) + +TEST_CREDENTIALS = OAuth2Credentials( + id="test-mcp-cred", + provider="mcp", + access_token=SecretStr("mock-mcp-token"), + refresh_token=SecretStr("mock-refresh"), + scopes=[], + title="Mock MCP credential", +) +TEST_CREDENTIALS_INPUT = { + "provider": TEST_CREDENTIALS.provider, + "id": TEST_CREDENTIALS.id, + "type": TEST_CREDENTIALS.type, + "title": TEST_CREDENTIALS.title, +} + + +MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]] + + +class MCPToolBlock(Block): + """ + A block that connects to an MCP server, lets the user pick a tool, + and executes it with dynamic input/output schema. + + The flow: + 1. User provides an MCP server URL (and optional credentials) + 2. Frontend calls the backend to get tool list from that URL + 3. User selects a tool from a dropdown (available_tools) + 4. The block's input schema updates to reflect the selected tool's parameters + 5. On execution, the block calls the MCP server to run the tool + """ + + class Input(BlockSchemaInput): + server_url: str = SchemaField( + description="URL of the MCP server (Streamable HTTP endpoint)", + placeholder="https://mcp.example.com/mcp", + ) + credentials: MCPCredentials = CredentialsField( + discriminator="server_url", + description="MCP server OAuth credentials", + default={}, + ) + selected_tool: str = SchemaField( + description="The MCP tool to execute", + placeholder="Select a tool", + default="", + ) + tool_input_schema: dict[str, Any] = SchemaField( + description="JSON Schema for the selected tool's input parameters. " + "Populated automatically when a tool is selected.", + default={}, + hidden=True, + ) + + tool_arguments: dict[str, Any] = SchemaField( + description="Arguments to pass to the selected MCP tool. 
" + "The fields here are defined by the tool's input schema.", + default={}, + ) + + @classmethod + def get_input_schema(cls, data: BlockInput) -> dict[str, Any]: + """Return the tool's input schema so the builder UI renders dynamic fields.""" + return data.get("tool_input_schema", {}) + + @classmethod + def get_input_defaults(cls, data: BlockInput) -> BlockInput: + """Return the current tool_arguments as defaults for the dynamic fields.""" + return data.get("tool_arguments", {}) + + @classmethod + def get_missing_input(cls, data: BlockInput) -> set[str]: + """Check which required tool arguments are missing.""" + required_fields = cls.get_input_schema(data).get("required", []) + tool_arguments = data.get("tool_arguments", {}) + return set(required_fields) - set(tool_arguments) + + @classmethod + def get_mismatch_error(cls, data: BlockInput) -> str | None: + """Validate tool_arguments against the tool's input schema.""" + tool_schema = cls.get_input_schema(data) + if not tool_schema: + return None + tool_arguments = data.get("tool_arguments", {}) + return validate_with_jsonschema(tool_schema, tool_arguments) + + class Output(BlockSchemaOutput): + result: Any = SchemaField(description="The result returned by the MCP tool") + error: str = SchemaField(description="Error message if the tool call failed") + + def __init__(self): + super().__init__( + id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", + description="Connect to any MCP server and execute its tools. " + "Provide a server URL, select a tool, and pass arguments dynamically.", + categories={BlockCategory.DEVELOPER_TOOLS}, + input_schema=MCPToolBlock.Input, + output_schema=MCPToolBlock.Output, + block_type=BlockType.MCP_TOOL, + test_credentials=TEST_CREDENTIALS, + test_input={ + "server_url": "https://mcp.example.com/mcp", + "credentials": TEST_CREDENTIALS_INPUT, + "selected_tool": "get_weather", + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + }, + test_output=[ + ( + "result", + {"weather": "sunny", "temperature": 20}, + ), + ], + test_mock={ + "_call_mcp_tool": lambda *a, **kw: { + "weather": "sunny", + "temperature": 20, + }, + }, + ) + + async def _call_mcp_tool( + self, + server_url: str, + tool_name: str, + arguments: dict[str, Any], + auth_token: str | None = None, + ) -> Any: + """Call a tool on the MCP server. 
Extracted for easy mocking in tests.""" + client = MCPClient(server_url, auth_token=auth_token) + await client.initialize() + result = await client.call_tool(tool_name, arguments) + + if result.is_error: + error_text = "" + for item in result.content: + if item.get("type") == "text": + error_text += item.get("text", "") + raise MCPClientError( + f"MCP tool '{tool_name}' returned an error: " + f"{error_text or 'Unknown error'}" + ) + + # Extract text content from the result + output_parts = [] + for item in result.content: + if item.get("type") == "text": + text = item.get("text", "") + # Try to parse as JSON for structured output + try: + output_parts.append(json.loads(text)) + except (json.JSONDecodeError, ValueError): + output_parts.append(text) + elif item.get("type") == "image": + output_parts.append( + { + "type": "image", + "data": item.get("data"), + "mimeType": item.get("mimeType"), + } + ) + elif item.get("type") == "resource": + output_parts.append(item.get("resource", {})) + + # If single result, unwrap + if len(output_parts) == 1: + return output_parts[0] + return output_parts if output_parts else None + + @staticmethod + async def _auto_lookup_credential( + user_id: str, server_url: str + ) -> "OAuth2Credentials | None": + """Auto-lookup stored MCP credential for a server URL. + + This is a fallback for nodes that don't have ``credentials`` explicitly + set (e.g. nodes created before the credential field was wired up). + """ + from backend.integrations.creds_manager import IntegrationCredentialsManager + from backend.integrations.providers import ProviderName + + try: + mgr = IntegrationCredentialsManager() + mcp_creds = await mgr.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + best: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == server_url + ): + if best is None or ( + (cred.access_token_expires_at or 0) + > (best.access_token_expires_at or 0) + ): + best = cred + if best: + best = await mgr.refresh_if_needed(user_id, best) + logger.info( + "Auto-resolved MCP credential %s for %s", best.id, server_url + ) + return best + except Exception: + logger.warning("Auto-lookup MCP credential failed", exc_info=True) + return None + + async def run( + self, + input_data: Input, + *, + user_id: str, + credentials: OAuth2Credentials | None = None, + **kwargs, + ) -> BlockOutput: + if not input_data.server_url: + yield "error", "MCP server URL is required" + return + + if not input_data.selected_tool: + yield "error", "No tool selected. Please select a tool from the dropdown." + return + + # Validate required tool arguments before calling the server. + # The executor-level validation is bypassed for MCP blocks because + # get_input_defaults() flattens tool_arguments, stripping tool_input_schema + # from the validation context. + required = set(input_data.tool_input_schema.get("required", [])) + if required: + missing = required - set(input_data.tool_arguments.keys()) + if missing: + yield "error", ( + f"Missing required argument(s): {', '.join(sorted(missing))}. " + f"Please fill in all required fields marked with * in the block form." + ) + return + + # If no credentials were injected by the executor (e.g. legacy nodes + # that don't have the credentials field set), try to auto-lookup + # the stored MCP credential for this server URL. 
+ if credentials is None: + credentials = await self._auto_lookup_credential( + user_id, input_data.server_url + ) + + auth_token = ( + credentials.access_token.get_secret_value() if credentials else None + ) + + try: + result = await self._call_mcp_tool( + server_url=input_data.server_url, + tool_name=input_data.selected_tool, + arguments=input_data.tool_arguments, + auth_token=auth_token, + ) + yield "result", result + except MCPClientError as e: + yield "error", str(e) + except Exception as e: + logger.exception(f"MCP tool call failed: {e}") + yield "error", f"MCP tool call failed: {str(e)}" diff --git a/autogpt_platform/backend/backend/blocks/mcp/client.py b/autogpt_platform/backend/backend/blocks/mcp/client.py new file mode 100644 index 0000000000..050349dbcc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/client.py @@ -0,0 +1,323 @@ +""" +MCP (Model Context Protocol) HTTP client. + +Implements the MCP Streamable HTTP transport for listing tools and calling tools +on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST. + +Handles both JSON and SSE (text/event-stream) response formats per the MCP spec. + +Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +""" + +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +@dataclass +class MCPTool: + """Represents an MCP tool discovered from a server.""" + + name: str + description: str + input_schema: dict[str, Any] + + +@dataclass +class MCPCallResult: + """Result from calling an MCP tool.""" + + content: list[dict[str, Any]] = field(default_factory=list) + is_error: bool = False + + +class MCPClientError(Exception): + """Raised when an MCP protocol error occurs.""" + + pass + + +class MCPClient: + """ + Async HTTP client for the MCP Streamable HTTP transport. + + Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST. + Supports optional Bearer token authentication. + """ + + def __init__( + self, + server_url: str, + auth_token: str | None = None, + ): + self.server_url = server_url.rstrip("/") + self.auth_token = auth_token + self._request_id = 0 + self._session_id: str | None = None + + def _next_id(self) -> int: + self._request_id += 1 + return self._request_id + + def _build_headers(self) -> dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + return headers + + def _build_jsonrpc_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + req: dict[str, Any] = { + "jsonrpc": "2.0", + "method": method, + "id": self._next_id(), + } + if params is not None: + req["params"] = params + return req + + @staticmethod + def _parse_sse_response(text: str) -> dict[str, Any]: + """Parse an SSE (text/event-stream) response body into JSON-RPC data. + + MCP servers may return responses as SSE with format: + event: message + data: {"jsonrpc":"2.0","result":{...},"id":1} + + We extract the last `data:` line that contains a JSON-RPC response + (i.e. has an "id" field), which is the reply to our request. 
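+ + For example (illustrative), a stream containing + data: {"jsonrpc":"2.0","method":"notifications/progress"} + data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1} + returns the second payload; the first has no "id" and is skipped as a notification.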
+ """ + last_data: dict[str, Any] | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + payload = stripped[len("data:") :].strip() + if not payload: + continue + try: + parsed = json.loads(payload) + # Only keep JSON-RPC responses (have "id"), skip notifications + if isinstance(parsed, dict) and "id" in parsed: + last_data = parsed + except (json.JSONDecodeError, ValueError): + continue + if last_data is None: + raise MCPClientError("No JSON-RPC response found in SSE stream") + return last_data + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """Send a JSON-RPC request to the MCP server and return the result. + + Handles both ``application/json`` and ``text/event-stream`` responses + as required by the MCP Streamable HTTP transport specification. + """ + payload = self._build_jsonrpc_request(method, params) + headers = self._build_headers() + + requests = Requests( + raise_for_status=True, + extra_headers=headers, + ) + response = await requests.post(self.server_url, json=payload) + + # Capture session ID from response (MCP Streamable HTTP transport) + session_id = response.headers.get("Mcp-Session-Id") + if session_id: + self._session_id = session_id + + content_type = response.headers.get("content-type", "") + if "text/event-stream" in content_type: + body = self._parse_sse_response(response.text()) + else: + try: + body = response.json() + except Exception as e: + raise MCPClientError( + f"MCP server returned non-JSON response: {e}" + ) from e + + if not isinstance(body, dict): + raise MCPClientError( + f"MCP server returned unexpected JSON type: {type(body).__name__}" + ) + + # Handle JSON-RPC error + if "error" in body: + error = body["error"] + if isinstance(error, dict): + raise MCPClientError( + f"MCP server error [{error.get('code', '?')}]: " + f"{error.get('message', 'Unknown error')}" + ) + raise MCPClientError(f"MCP server error: {error}") + + return body.get("result") + + async def _send_notification(self, method: str) -> None: + """Send a JSON-RPC notification (no id, no response expected).""" + headers = self._build_headers() + notification = {"jsonrpc": "2.0", "method": method} + requests = Requests( + raise_for_status=False, + extra_headers=headers, + ) + await requests.post(self.server_url, json=notification) + + async def discover_auth(self) -> dict[str, Any] | None: + """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec). + + Returns ``None`` if the server doesn't require auth, otherwise returns + a dict with: + - ``authorization_servers``: list of authorization server URLs + - ``resource``: the resource indicator URL (usually the MCP endpoint) + - ``scopes_supported``: optional list of supported scopes + + The caller can then fetch the authorization server metadata to get + ``authorization_endpoint``, ``token_endpoint``, etc. 
+ """ + from urllib.parse import urlparse + + parsed = urlparse(self.server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + + # Build candidates for protected-resource metadata (per RFC 9728) + path = parsed.path.rstrip("/") + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-protected-resource{path}") + candidates.append(f"{base}/.well-known/oauth-protected-resource") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_servers" in data: + return data + except Exception: + continue + + return None + + async def discover_auth_server_metadata( + self, auth_server_url: str + ) -> dict[str, Any] | None: + """Fetch the OAuth Authorization Server Metadata (RFC 8414). + + Given an authorization server URL, returns a dict with: + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``registration_endpoint`` (for dynamic client registration) + - ``scopes_supported`` + - ``code_challenge_methods_supported`` + - etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(auth_server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + + # Try standard metadata endpoints (RFC 8414 and OpenID Connect) + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-authorization-server{path}") + candidates.append(f"{base}/.well-known/oauth-authorization-server") + candidates.append(f"{base}/.well-known/openid-configuration") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_endpoint" in data: + return data + except Exception: + continue + + return None + + async def initialize(self) -> dict[str, Any]: + """ + Send the MCP initialize request. + + This is required by the MCP protocol before any other requests. + Returns the server's capabilities. + """ + result = await self._send_request( + "initialize", + { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"}, + }, + ) + # Send initialized notification (no response expected) + await self._send_notification("notifications/initialized") + + return result or {} + + async def list_tools(self) -> list[MCPTool]: + """ + Discover available tools from the MCP server. + + Returns a list of MCPTool objects with name, description, and input schema. + """ + result = await self._send_request("tools/list") + if not result or "tools" not in result: + return [] + + tools = [] + for tool_data in result["tools"]: + tools.append( + MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + ) + ) + return tools + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] + ) -> MCPCallResult: + """ + Call a tool on the MCP server. + + Args: + tool_name: The name of the tool to call. + arguments: The arguments to pass to the tool. + + Returns: + MCPCallResult with the tool's response content. 
+ """ + result = await self._send_request( + "tools/call", + {"name": tool_name, "arguments": arguments}, + ) + if not result: + return MCPCallResult(is_error=True) + + return MCPCallResult( + content=result.get("content", []), + is_error=result.get("isError", False), + ) diff --git a/autogpt_platform/backend/backend/blocks/mcp/oauth.py b/autogpt_platform/backend/backend/blocks/mcp/oauth.py new file mode 100644 index 0000000000..2228336cd3 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/oauth.py @@ -0,0 +1,204 @@ +""" +MCP OAuth handler for MCP servers that use OAuth 2.1 authorization. + +Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed, +MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata. +This handler accepts those endpoints at construction time. +""" + +import logging +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +class MCPOAuthHandler(BaseOAuthHandler): + """ + OAuth handler for MCP servers with dynamically-discovered endpoints. + + Construction requires the authorization and token endpoint URLs, + which are obtained via MCP OAuth metadata discovery + (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``). + """ + + PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP + DEFAULT_SCOPES: ClassVar[list[str]] = [] + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + *, + authorize_url: str, + token_url: str, + revoke_url: str | None = None, + resource_url: str | None = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.authorize_url = authorize_url + self.token_url = token_url + self.revoke_url = revoke_url + self.resource_url = resource_url + + def get_login_url( + self, + scopes: list[str], + state: str, + code_challenge: Optional[str], + ) -> str: + scopes = self.handle_default_scopes(scopes) + + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": state, + } + if scopes: + params["scope"] = " ".join(scopes) + # PKCE (S256) — included when the caller provides a code_challenge + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + # MCP spec requires resource indicator (RFC 8707) + if self.resource_url: + params["resource"] = self.resource_url + + return f"{self.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, + code: str, + scopes: list[str], + code_verifier: Optional[str], + ) -> OAuth2Credentials: + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if code_verifier: + data["code_verifier"] = code_verifier + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token exchange failed: 
{tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth token response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + provider=self.PROVIDER_NAME, + title=None, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else None + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=None, + scopes=scopes, + metadata={ + "mcp_token_url": self.token_url, + "mcp_resource_url": self.resource_url, + }, + ) + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + if not credentials.refresh_token: + raise ValueError("No refresh token available for MCP OAuth credentials") + + data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token refresh failed: {tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth refresh response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + id=credentials.id, + provider=self.PROVIDER_NAME, + title=credentials.title, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else credentials.refresh_token + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=credentials.refresh_token_expires_at, + scopes=credentials.scopes, + metadata=credentials.metadata, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + if not self.revoke_url: + return False + + try: + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + "client_id": self.client_id, + } + await Requests().post( + self.revoke_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return True + except Exception: + logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True) + return False diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py new file mode 100644 index 0000000000..7818fac9ce --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py @@ -0,0 +1,109 @@ +""" +End-to-end tests against a real public MCP server. + +These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp) +which is publicly accessible without authentication and returns SSE responses. + +Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped +independently of the rest of the test suite (they require network access). 
+""" + +import json +import os + +import pytest + +from backend.blocks.mcp.client import MCPClient + +# Public MCP server that requires no authentication +OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp" + +# Skip all tests in this module unless RUN_E2E env var is set +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests" +) + + +class TestRealMCPServer: + """Tests against the live OpenAI docs MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + """Verify we can complete the MCP handshake with a real server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "openai-docs-mcp" + assert "tools" in result.get("capabilities", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + """Verify we can discover tools from a real MCP server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) >= 3 # server has at least 5 tools as of writing + + tool_names = {t.name for t in tools} + # These tools are documented and should be stable + assert "search_openai_docs" in tool_names + assert "list_openai_docs" in tool_names + assert "fetch_openai_doc" in tool_names + + # Verify schema structure + search_tool = next(t for t in tools if t.name == "search_openai_docs") + assert "query" in search_tool.input_schema.get("properties", {}) + assert "query" in search_tool.input_schema.get("required", []) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_list_api_endpoints(self): + """Call the list_api_endpoints tool and verify we get real data.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool("list_api_endpoints", {}) + + assert not result.is_error + assert len(result.content) >= 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert "paths" in data or "urls" in data + # The OpenAI API should have many endpoints + total = data.get("total", len(data.get("paths", []))) + assert total > 50 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_search(self): + """Search for docs and verify we get results.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool( + "search_openai_docs", {"query": "chat completions", "limit": 3} + ) + + assert not result.is_error + assert len(result.content) >= 1 + + @pytest.mark.asyncio(loop_scope="session") + async def test_sse_response_handling(self): + """Verify the client correctly handles SSE responses from a real server. + + This is the key test — our local test server returns JSON, + but real MCP servers typically return SSE. This proves the + SSE parsing works end-to-end. 
+ """ + client = MCPClient(OPENAI_DOCS_MCP_URL) + # initialize() internally calls _send_request which must parse SSE + result = await client.initialize() + + # If we got here without error, SSE parsing works + assert isinstance(result, dict) + assert "protocolVersion" in result + + # Also verify list_tools works (another SSE response) + tools = await client.list_tools() + assert len(tools) > 0 + assert all(hasattr(t, "name") for t in tools) diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_integration.py b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py new file mode 100644 index 0000000000..70658dbaaf --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py @@ -0,0 +1,389 @@ +""" +Integration tests for MCP client and MCPToolBlock against a real HTTP server. + +These tests spin up a local MCP test server and run the full client/block flow +against it — no mocking, real HTTP requests. +""" + +import asyncio +import json +import threading +from unittest.mock import patch + +import pytest +from aiohttp import web +from pydantic import SecretStr + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.test_server import create_test_mcp_app +from backend.data.model import OAuth2Credentials + +MOCK_USER_ID = "test-user-integration" + + +class _MCPTestServer: + """ + Run an MCP test server in a background thread with its own event loop. + This avoids event loop conflicts with pytest-asyncio. + """ + + def __init__(self, auth_token: str | None = None): + self.auth_token = auth_token + self.url: str = "" + self._runner: web.AppRunner | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = threading.Event() + + def _run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._start()) + self._started.set() + self._loop.run_forever() + + async def _start(self): + app = create_test_mcp_app(auth_token=self.auth_token) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + self.url = f"http://127.0.0.1:{port}/mcp" + + def start(self): + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + if not self._started.wait(timeout=5): + raise RuntimeError("MCP test server failed to start within 5 seconds") + return self + + def stop(self): + if self._loop and self._runner: + asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result( + timeout=5 + ) + self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread: + self._thread.join(timeout=5) + + +@pytest.fixture(scope="module") +def mcp_server(): + """Start a local MCP test server in a background thread.""" + server = _MCPTestServer() + server.start() + yield server.url + server.stop() + + +@pytest.fixture(scope="module") +def mcp_server_with_auth(): + """Start a local MCP test server with auth in a background thread.""" + server = _MCPTestServer(auth_token="test-secret-token") + server.start() + yield server.url, "test-secret-token" + server.stop() + + +@pytest.fixture(autouse=True) +def _allow_localhost(): + """ + Allow 127.0.0.1 through SSRF protection for integration tests. + + The Requests class blocks private IPs by default. 
We patch the Requests + constructor to always include 127.0.0.1 as a trusted origin so the local + test server is reachable. + """ + from backend.util.request import Requests + + original_init = Requests.__init__ + + def patched_init(self, *args, **kwargs): + trusted = list(kwargs.get("trusted_origins") or []) + trusted.append("http://127.0.0.1") + kwargs["trusted_origins"] = trusted + original_init(self, *args, **kwargs) + + with patch.object(Requests, "__init__", patched_init): + yield + + +def _make_client(url: str, auth_token: str | None = None) -> MCPClient: + """Create an MCPClient for integration tests.""" + return MCPClient(url, auth_token=auth_token) + + +# ── MCPClient integration tests ────────────────────────────────────── + + +class TestMCPClientIntegration: + """Test MCPClient against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self, mcp_server): + client = _make_client(mcp_server) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert result["serverInfo"]["name"] == "test-mcp-server" + assert "tools" in result["capabilities"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) == 3 + + tool_names = {t.name for t in tools} + assert tool_names == {"get_weather", "add_numbers", "echo"} + + # Check get_weather schema + weather = next(t for t in tools if t.name == "get_weather") + assert weather.description == "Get current weather for a city" + assert "city" in weather.input_schema["properties"] + assert weather.input_schema["required"] == ["city"] + + # Check add_numbers schema + add = next(t for t in tools if t.name == "add_numbers") + assert "a" in add.input_schema["properties"] + assert "b" in add.input_schema["properties"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_get_weather(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert data["city"] == "London" + assert data["temperature"] == 22 + assert data["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_add_numbers(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("add_numbers", {"a": 3, "b": 7}) + + assert not result.is_error + data = json.loads(result.content[0]["text"]) + assert data["result"] == 10 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_echo(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("echo", {"message": "Hello MCP!"}) + + assert not result.is_error + assert result.content[0]["text"] == "Hello MCP!" 
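For reference, the echo and weather calls exercised above ride on the JSON-RPC 2.0 framing and SSE handling implemented in client.py. The sketch below uses illustrative values only (the tool name, id, and payload are not taken from these fixtures); it shows the shape of a tools/call request as _build_jsonrpc_request produces it, and how _parse_sse_response recovers the JSON-RPC reply from a text/event-stream body:

from backend.blocks.mcp.client import MCPClient

# Shape of the JSON-RPC request the client posts for a tool call
# (illustrative values; compare _build_jsonrpc_request in client.py).
request_payload = {
    "jsonrpc": "2.0",
    "method": "tools/call",
    "id": 1,
    "params": {"name": "echo", "arguments": {"message": "Hello MCP!"}},
}

# An SSE-framed reply, as returned by servers that answer with text/event-stream.
sse_body = (
    "event: message\n"
    'data: {"jsonrpc":"2.0","result":{"content":[{"type":"text","text":"Hello MCP!"}],"isError":false},"id":1}\n'
    "\n"
)

# _parse_sse_response keeps the last data: line that carries an "id".
reply = MCPClient._parse_sse_response(sse_body)
assert reply["id"] == request_payload["id"]
assert reply["result"]["content"][0]["text"] == "Hello MCP!"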
+ + @pytest.mark.asyncio(loop_scope="session") + async def test_call_unknown_tool(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("nonexistent_tool", {}) + + assert result.is_error + assert "Unknown tool" in result.content[0]["text"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_success(self, mcp_server_with_auth): + url, token = mcp_server_with_auth + client = _make_client(url, auth_token=token) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + + tools = await client.list_tools() + assert len(tools) == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_failure(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url, auth_token="wrong-token") + + with pytest.raises(Exception): + await client.initialize() + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_missing(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url) + + with pytest.raises(Exception): + await client.initialize() + + +# ── MCPToolBlock integration tests ─────────────────────────────────── + + +class TestMCPToolBlockIntegration: + """Test MCPToolBlock end-to-end against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_get_weather(self, mcp_server): + """Full flow: discover tools, select one, execute it.""" + # Step 1: Discover tools (simulating what the frontend/API would do) + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + assert len(tools) == 3 + + # Step 2: User selects "get_weather" and we get its schema + weather_tool = next(t for t in tools if t.name == "get_weather") + + # Step 3: Execute the block — no credentials (public server) + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="get_weather", + tool_input_schema=weather_tool.input_schema, + tool_arguments={"city": "Paris"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + result = outputs[0][1] + assert result["city"] == "Paris" + assert result["temperature"] == 22 + assert result["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_add_numbers(self, mcp_server): + """Full flow for add_numbers tool.""" + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + add_tool = next(t for t in tools if t.name == "add_numbers") + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="add_numbers", + tool_input_schema=add_tool.input_schema, + tool_arguments={"a": 42, "b": 58}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1]["result"] == 100 + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_echo_plain_text(self, mcp_server): + """Verify plain text (non-JSON) responses work.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": 
"Hello from AutoGPT!"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Hello from AutoGPT!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_unknown_tool_yields_error(self, mcp_server): + """Calling an unknown tool should yield an error output.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="nonexistent_tool", + tool_arguments={}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "error" + assert "returned an error" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_with_auth(self, mcp_server_with_auth): + """Full flow with authentication via credentials kwarg.""" + url, token = mcp_server_with_auth + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=url, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Authenticated!"}, + ) + + # Pass credentials via the standard kwarg (as the executor would) + test_creds = OAuth2Credentials( + id="test-cred", + provider="mcp", + access_token=SecretStr(token), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Authenticated!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_credentials_runs_without_auth(self, mcp_server): + """Block runs without auth when no credentials are provided.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "No auth needed"}, + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=None + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "No auth needed" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py new file mode 100644 index 0000000000..8cb49b0fee --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -0,0 +1,619 @@ +""" +Tests for MCP client and MCPToolBlock. 
+""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError +from backend.util.test import execute_block_test + +# ── SSE parsing unit tests ─────────────────────────────────────────── + + +class TestSSEParsing: + """Tests for SSE (text/event-stream) response parsing.""" + + def test_parse_sse_simple(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"tools": []} + assert body["id"] == 1 + + def test_parse_sse_with_notifications(self): + """SSE streams can contain notifications (no id) before the response.""" + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","method":"some/notification"}\n' + "\n" + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"ok": True} + assert body["id"] == 2 + + def test_parse_sse_error_response(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n' + ) + body = MCPClient._parse_sse_response(sse) + assert "error" in body + assert body["error"]["code"] == -32600 + + def test_parse_sse_no_data_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("event: message\n\n") + + def test_parse_sse_empty_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("") + + def test_parse_sse_ignores_non_data_lines(self): + sse = ( + ": comment line\n" + "event: message\n" + "id: 123\n" + 'data: {"jsonrpc":"2.0","result":"ok","id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "ok" + + def test_parse_sse_uses_last_response(self): + """If multiple responses exist, use the last one.""" + sse = ( + 'data: {"jsonrpc":"2.0","result":"first","id":1}\n' + "\n" + 'data: {"jsonrpc":"2.0","result":"second","id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "second" + + +# ── MCPClient unit tests ───────────────────────────────────────────── + + +class TestMCPClient: + """Tests for the MCP HTTP client.""" + + def test_build_headers_without_auth(self): + client = MCPClient("https://mcp.example.com") + headers = client._build_headers() + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + + def test_build_headers_with_auth(self): + client = MCPClient("https://mcp.example.com", auth_token="my-token") + headers = client._build_headers() + assert headers["Authorization"] == "Bearer my-token" + + def test_build_jsonrpc_request(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request("tools/list") + assert req["jsonrpc"] == "2.0" + assert req["method"] == "tools/list" + assert "id" in req + assert "params" not in req + + def test_build_jsonrpc_request_with_params(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request( + "tools/call", {"name": "test", "arguments": {"x": 1}} + ) + assert req["params"] == {"name": "test", "arguments": {"x": 1}} + + def test_request_id_increments(self): + client = MCPClient("https://mcp.example.com") + req1 = client._build_jsonrpc_request("tools/list") + req2 = 
client._build_jsonrpc_request("tools/list") + assert req2["id"] > req1["id"] + + def test_server_url_trailing_slash_stripped(self): + client = MCPClient("https://mcp.example.com/mcp/") + assert client.server_url == "https://mcp.example.com/mcp" + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_success(self): + client = MCPClient("https://mcp.example.com") + + mock_response = AsyncMock() + mock_response.json.return_value = { + "jsonrpc": "2.0", + "result": {"tools": []}, + "id": 1, + } + + with patch.object(client, "_send_request", return_value={"tools": []}): + result = await client._send_request("tools/list") + assert result == {"tools": []} + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_error(self): + client = MCPClient("https://mcp.example.com") + + async def mock_send(*args, **kwargs): + raise MCPClientError("MCP server error [-32600]: Invalid Request") + + with patch.object(client, "_send_request", side_effect=mock_send): + with pytest.raises(MCPClientError, match="Invalid Request"): + await client._send_request("tools/list") + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + { + "name": "search", + "description": "Search the web", + "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + ] + } + + with patch.object(client, "_send_request", return_value=mock_result): + tools = await client.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "get_weather" + assert tools[0].description == "Get current weather for a city" + assert tools[0].input_schema["properties"]["city"]["type"] == "string" + assert tools[1].name == "search" + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_empty(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value={"tools": []}): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_success(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [ + {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})} + ], + "isError": False, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_error(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [{"type": "text", "text": "City not found"}], + "isError": True, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "???"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def 
test_call_tool_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "test-server", "version": "1.0.0"}, + } + + with ( + patch.object(client, "_send_request", return_value=mock_result) as mock_req, + patch.object(client, "_send_notification") as mock_notif, + ): + result = await client.initialize() + + mock_req.assert_called_once() + mock_notif.assert_called_once_with("notifications/initialized") + assert result["protocolVersion"] == "2025-03-26" + + +# ── MCPToolBlock unit tests ────────────────────────────────────────── + +MOCK_USER_ID = "test-user-123" + + +class TestMCPToolBlock: + """Tests for the MCPToolBlock.""" + + def test_block_instantiation(self): + block = MCPToolBlock() + assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4" + assert block.name == "MCPToolBlock" + + def test_input_schema_has_required_fields(self): + block = MCPToolBlock() + schema = block.input_schema.jsonschema() + props = schema.get("properties", {}) + assert "server_url" in props + assert "selected_tool" in props + assert "tool_arguments" in props + assert "credentials" in props + + def test_output_schema(self): + block = MCPToolBlock() + schema = block.output_schema.jsonschema() + props = schema.get("properties", {}) + assert "result" in props + assert "error" in props + + def test_get_input_schema_with_tool_schema(self): + tool_schema = { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + data = {"tool_input_schema": tool_schema} + result = MCPToolBlock.Input.get_input_schema(data) + assert result == tool_schema + + def test_get_input_schema_without_tool_schema(self): + result = MCPToolBlock.Input.get_input_schema({}) + assert result == {} + + def test_get_input_defaults(self): + data = {"tool_arguments": {"city": "London"}} + result = MCPToolBlock.Input.get_input_defaults(data) + assert result == {"city": "London"} + + def test_get_missing_input(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "units": {"type": "string"}, + }, + "required": ["city", "units"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == {"units"} + + def test_get_missing_input_all_present(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == set() + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_mock(self): + """Test the block using the built-in test infrastructure.""" + block = MCPToolBlock() + await execute_block_test(block) + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_server_url(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="", + selected_tool="test", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [("error", "MCP server URL is required")] + + 
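The Input hooks exercised above are what make the block's inputs dynamic: tool_input_schema carries the JSON Schema discovered from the MCP server, tool_arguments carries the user-supplied values, and get_missing_input reports required schema fields that have no value yet. A standalone sketch of that interaction, mirroring the unit tests above (the weather schema and server URL are illustrative):

from backend.blocks.mcp.block import MCPToolBlock

node_input = {
    "server_url": "https://mcp.example.com/mcp",  # illustrative server
    "selected_tool": "get_weather",
    "tool_input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}, "units": {"type": "string"}},
        "required": ["city", "units"],
    },
    "tool_arguments": {"city": "London"},  # "units" not supplied yet
}

# The discovered tool schema is surfaced as the block's effective input schema,
# and the missing required argument is reported for graph validation.
assert MCPToolBlock.Input.get_input_schema(node_input) == node_input["tool_input_schema"]
assert MCPToolBlock.Input.get_missing_input(node_input) == {"units"}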
@pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_tool(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [ + ("error", "No tool selected. Please select a tool from the dropdown.") + ] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_success(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="get_weather", + tool_input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + tool_arguments={"city": "London"}, + ) + + async def mock_call(*args, **kwargs): + return {"temp": 20, "city": "London"} + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == {"temp": 20, "city": "London"} + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_mcp_error(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="bad_tool", + ) + + async def mock_call(*args, **kwargs): + raise MCPClientError("Tool not found") + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert outputs[0][0] == "error" + assert "Tool not found" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_parses_json_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": '{"temp": 20}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == {"temp": 20} + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_plain_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Hello, world!"}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == "Hello, world!" 
+ + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_multiple_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Part 1"}, + {"type": "text", "text": '{"part": 2}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == ["Part 1", {"part": 2}] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_error_result(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[{"type": "text", "text": "Something went wrong"}], + is_error=True, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + with pytest.raises(MCPClientError, match="returned an error"): + await block._call_mcp_tool("https://mcp.example.com", "test_tool", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_image_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_credentials(self): + """Verify the block uses OAuth2Credentials and passes auth token.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + test_creds = OAuth2Credentials( + id="cred-123", + provider="mcp", + access_token=SecretStr("resolved-token"), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + async for _ in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + pass + + assert captured_tokens == ["resolved-token"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_without_credentials(self): + """Verify the block works without credentials (public server).""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert captured_tokens 
== [None] + assert outputs == [("result", "ok")] diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py new file mode 100644 index 0000000000..e9a42f68ea --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py @@ -0,0 +1,242 @@ +""" +Tests for MCP OAuth handler. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials + + +def _mock_response(json_data: dict, status: int = 200) -> MagicMock: + """Create a mock Response with synchronous json() (matching Requests.Response).""" + resp = MagicMock() + resp.status = status + resp.ok = 200 <= status < 300 + resp.json.return_value = json_data + return resp + + +class TestMCPOAuthHandler: + """Tests for the MCPOAuthHandler.""" + + def _make_handler(self, **overrides) -> MCPOAuthHandler: + defaults = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "redirect_uri": "https://app.example.com/callback", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + } + defaults.update(overrides) + return MCPOAuthHandler(**defaults) + + def test_get_login_url_basic(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read", "write"], + state="random-state-token", + code_challenge="S256-challenge-value", + ) + + assert "https://auth.example.com/authorize?" in url + assert "response_type=code" in url + assert "client_id=test-client-id" in url + assert "state=random-state-token" in url + assert "code_challenge=S256-challenge-value" in url + assert "code_challenge_method=S256" in url + assert "scope=read+write" in url + + def test_get_login_url_with_resource(self): + handler = self._make_handler(resource_url="https://mcp.example.com/mcp") + url = handler.get_login_url( + scopes=[], state="state", code_challenge="challenge" + ) + + assert "resource=https" in url + + def test_get_login_url_without_pkce(self): + handler = self._make_handler() + url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None) + + assert "code_challenge" not in url + assert "code_challenge_method" not in url + + @pytest.mark.asyncio(loop_scope="session") + async def test_exchange_code_for_tokens(self): + handler = self._make_handler() + + resp = _mock_response( + { + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + creds = await handler.exchange_code_for_tokens( + code="auth-code", + scopes=["read"], + code_verifier="pkce-verifier", + ) + + assert isinstance(creds, OAuth2Credentials) + assert creds.access_token.get_secret_value() == "new-access-token" + assert creds.refresh_token is not None + assert creds.refresh_token.get_secret_value() == "new-refresh-token" + assert creds.scopes == ["read"] + assert creds.access_token_expires_at is not None + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens(self): + handler = self._make_handler() + + existing_creds = OAuth2Credentials( + id="existing-id", + provider="mcp", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("old-refresh"), + 
scopes=["read"], + title="test", + ) + + resp = _mock_response( + { + "access_token": "refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + refreshed = await handler._refresh_tokens(existing_creds) + + assert refreshed.id == "existing-id" + assert refreshed.access_token.get_secret_value() == "refreshed-token" + assert refreshed.refresh_token is not None + assert refreshed.refresh_token.get_secret_value() == "new-refresh" + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens_no_refresh_token(self): + handler = self._make_handler() + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=["read"], + title="test", + ) + + with pytest.raises(ValueError, match="No refresh token"): + await handler._refresh_tokens(creds) + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_no_url(self): + handler = self._make_handler(revoke_url=None) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + result = await handler.revoke_tokens(creds) + assert result is False + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_with_url(self): + handler = self._make_handler(revoke_url="https://auth.example.com/revoke") + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + resp = _mock_response({}, status=200) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + result = await handler.revoke_tokens(creds) + + assert result is True + + +class TestMCPClientDiscovery: + """Tests for MCPClient OAuth metadata discovery.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + metadata = { + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + + resp = _mock_response(metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is not None + assert result["authorization_servers"] == ["https://auth.example.com"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_not_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + resp = _mock_response({}, status=404) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_server_metadata(self): + client = MCPClient("https://mcp.example.com/mcp") + + server_metadata = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + "code_challenge_methods_supported": ["S256"], + } + + resp = _mock_response(server_metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: 
+ instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth_server_metadata( + "https://auth.example.com" + ) + + assert result is not None + assert result["authorization_endpoint"] == "https://auth.example.com/authorize" + assert result["token_endpoint"] == "https://auth.example.com/token" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_server.py b/autogpt_platform/backend/backend/blocks/mcp/test_server.py new file mode 100644 index 0000000000..a6732932bc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_server.py @@ -0,0 +1,162 @@ +""" +Minimal MCP server for integration testing. + +Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST) +with a few sample tools. Runs on localhost with a random available port. +""" + +import json +import logging + +from aiohttp import web + +logger = logging.getLogger(__name__) + +# Sample tools this test server exposes +TEST_TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + }, + "required": ["city"], + }, + }, + { + "name": "add_numbers", + "description": "Add two numbers together", + "inputSchema": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + { + "name": "echo", + "description": "Echo back the input message", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "Message to echo"}, + }, + "required": ["message"], + }, + }, +] + + +def _handle_initialize(params: dict) -> dict: + return { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"}, + } + + +def _handle_tools_list(params: dict) -> dict: + return {"tools": TEST_TOOLS} + + +def _handle_tools_call(params: dict) -> dict: + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_weather": + city = arguments.get("city", "Unknown") + return { + "content": [ + { + "type": "text", + "text": json.dumps( + {"city": city, "temperature": 22, "condition": "sunny"} + ), + } + ], + } + + elif tool_name == "add_numbers": + a = arguments.get("a", 0) + b = arguments.get("b", 0) + return { + "content": [{"type": "text", "text": json.dumps({"result": a + b})}], + } + + elif tool_name == "echo": + message = arguments.get("message", "") + return { + "content": [{"type": "text", "text": message}], + } + + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } + + +HANDLERS = { + "initialize": _handle_initialize, + "tools/list": _handle_tools_list, + "tools/call": _handle_tools_call, +} + + +async def handle_mcp_request(request: web.Request) -> web.Response: + """Handle incoming MCP JSON-RPC 2.0 requests.""" + # Check auth if configured + expected_token = request.app.get("auth_token") + if expected_token: + auth_header = request.headers.get("Authorization", "") + if auth_header != f"Bearer {expected_token}": + return web.json_response( + { + "jsonrpc": "2.0", + "error": {"code": -32001, "message": "Unauthorized"}, + "id": None, + }, + status=401, + ) + + body = await request.json() + + # Handle notifications (no id field) — just acknowledge + 
if "id" not in body: + return web.Response(status=202) + + method = body.get("method", "") + params = body.get("params", {}) + request_id = body.get("id") + + handler = HANDLERS.get(method) + if not handler: + return web.json_response( + { + "jsonrpc": "2.0", + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + "id": request_id, + } + ) + + result = handler(params) + return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id}) + + +def create_test_mcp_app(auth_token: str | None = None) -> web.Application: + """Create an aiohttp app that acts as an MCP server.""" + app = web.Application() + app.router.add_post("/mcp", handle_mcp_request) + if auth_token: + app["auth_token"] = auth_token + return app diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index e9975de3c9..801fe1880d 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -33,6 +33,7 @@ from backend.util import type as type_utils from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError from backend.util.json import SafeJson from backend.util.models import Pagination +from backend.util.request import parse_url from .block import BlockInput from .db import BaseDbModel @@ -449,6 +450,9 @@ class GraphModel(Graph, GraphMeta): continue if ProviderName.HTTP in field.provider: continue + # MCP credentials are intentionally split by server URL + if ProviderName.MCP in field.provider: + continue # If this happens, that means a block implementation probably needs # to be updated. @@ -505,6 +509,18 @@ class GraphModel(Graph, GraphMeta): "required": ["id", "provider", "type"], } + # Add a descriptive display title when URL-based discriminator values + # are present (e.g. "mcp.sentry.dev" instead of just "Mcp") + if ( + field_info.discriminator + and not field_info.discriminator_mapping + and field_info.discriminator_values + ): + hostnames = sorted( + parse_url(str(v)).netloc for v in field_info.discriminator_values + ) + field_schema["display_name"] = ", ".join(hostnames) + # Add other (optional) field info items field_schema.update( field_info.model_dump( @@ -549,8 +565,17 @@ class GraphModel(Graph, GraphMeta): for graph in [self] + self.sub_graphs: for node in graph.nodes: - # Track if this node requires credentials (credentials_optional=False means required) - node_required_map[node.id] = not node.credentials_optional + # A node's credentials are optional if either: + # 1. The node metadata says so (credentials_optional=True), or + # 2. All credential fields on the block have defaults (not required by schema) + block_required = node.block.input_schema.get_required_fields() + creds_required_by_schema = any( + fname in block_required + for fname in node.block.input_schema.get_credentials_fields() + ) + node_required_map[node.id] = ( + not node.credentials_optional and creds_required_by_schema + ) for ( field_name, @@ -776,6 +801,19 @@ class GraphModel(Graph, GraphMeta): "'credentials' and `*_credentials` are reserved" ) + # Check custom block-level validation (e.g., MCP dynamic tool arguments). + # Blocks can override get_missing_input to report additional missing fields + # beyond the standard top-level required fields. 
+ if for_run: + credential_fields = InputSchema.get_credentials_fields() + custom_missing = InputSchema.get_missing_input(node.input_default) + for field_name in custom_missing: + if ( + field_name not in provided_inputs + and field_name not in credential_fields + ): + node_errors[node.id][field_name] = "This field is required" + # Get input schema properties and check dependencies input_fields = InputSchema.model_fields diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 442c8ed4be..3cb6f24b87 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -462,3 +462,120 @@ def test_node_credentials_optional_with_other_metadata(): assert node.credentials_optional is True assert node.metadata["position"] == {"x": 100, "y": 200} assert node.metadata["customized_name"] == "My Custom Node" + + +# ============================================================================ +# Tests for MCP Credential Deduplication +# ============================================================================ + + +def test_mcp_credential_combine_different_servers(): + """Two MCP credential fields with different server URLs should produce + separate entries when combined (not merged into one).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_sentry = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_linear = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.linear.app/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_sentry, ("node-sentry", "credentials")), + (field_linear, ("node-linear", "credentials")), + ) + + # Should produce 2 separate credential entries + assert len(combined) == 2, ( + f"Expected 2 credential entries for 2 MCP blocks with different servers, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + # Each entry should contain the server hostname in its key + keys = list(combined.keys()) + assert any( + "mcp.sentry.dev" in k for k in keys + ), f"Expected 'mcp.sentry.dev' in one key, got {keys}" + assert any( + "mcp.linear.app" in k for k in keys + ), f"Expected 'mcp.linear.app' in one key, got {keys}" + + +def test_mcp_credential_combine_same_server(): + """Two MCP credential fields with the same server URL should be combined + into one credential entry.""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + + combined = 
CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 credential entry (same server URL) + assert len(combined) == 1, ( + f"Expected 1 credential entry for 2 MCP blocks with same server, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + +def test_mcp_credential_combine_no_discriminator_values(): + """MCP credential fields without discriminator_values should be merged + into a single entry (backwards compat for blocks without server_url set).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 entry (no URL differentiation) + assert len(combined) == 1, ( + f"Expected 1 credential entry for MCP blocks without discriminator_values, " + f"got {len(combined)}: {list(combined.keys())}" + ) diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index c45994f214..ee5d701da6 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -29,6 +29,7 @@ from pydantic import ( GetCoreSchemaHandler, SecretStr, field_serializer, + model_validator, ) from pydantic_core import ( CoreSchema, @@ -503,6 +504,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): provider: CP type: CT + @model_validator(mode="before") + @classmethod + def _normalize_legacy_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug. + + Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"`` + instead of the plain value. Old stored credential references may have + ``provider: "ProviderName.MCP"`` instead of ``"mcp"``. + """ + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + @classmethod def allowed_providers(cls) -> tuple[ProviderName, ...] | None: return get_args(cls.model_fields["provider"].annotation) @@ -609,11 +629,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): ] = defaultdict(list) for field, key in fields: - if field.provider == frozenset([ProviderName.HTTP]): - # HTTP host-scoped credentials can have different hosts that reqires different credential sets. - # Group by host extracted from the URL + if ( + field.discriminator + and not field.discriminator_mapping + and field.discriminator_values + ): + # URL-based discrimination (e.g. HTTP host-scoped, MCP server URL): + # Each unique host gets its own credential entry. + provider_prefix = next(iter(field.provider)) + # Use .value for enum types to get the plain string (e.g. 
"mcp" not "ProviderName.MCP") + prefix_str = getattr(provider_prefix, "value", str(provider_prefix)) providers = frozenset( - [cast(CP, "http")] + [cast(CP, prefix_str)] + [ cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index a235e72ef1..d961d0e293 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -20,6 +20,7 @@ from backend.blocks import get_block from backend.blocks._base import BlockSchema from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentOutputBlock +from backend.blocks.mcp.block import MCPToolBlock from backend.data import redis_client as redis from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry from backend.data.credit import UsageTransactionMetadata @@ -228,6 +229,18 @@ async def execute_node( _input_data.nodes_input_masks = nodes_input_masks _input_data.user_id = user_id input_data = _input_data.model_dump() + elif isinstance(node_block, MCPToolBlock): + _mcp_data = MCPToolBlock.Input(**node.input_default) + # Dynamic tool fields are flattened to top-level by validate_exec + # (via get_input_defaults). Collect them back into tool_arguments. + tool_schema = _mcp_data.tool_input_schema + tool_props = set(tool_schema.get("properties", {}).keys()) + merged_args = {**_mcp_data.tool_arguments} + for key in tool_props: + if key in input_data: + merged_args[key] = input_data[key] + _mcp_data.tool_arguments = merged_args + input_data = _mcp_data.model_dump() data.inputs = input_data # Execute the node @@ -264,8 +277,34 @@ async def execute_node( # Handle regular credentials fields for field_name, input_type in input_model.get_credentials_fields().items(): - credentials_meta = input_type(**input_data[field_name]) - credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id) + field_value = input_data.get(field_name) + if not field_value or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + # No credentials configured — nullify so JSON schema validation + # doesn't choke on the empty default `{}`. + input_data[field_name] = None + continue # Block runs without credentials + + credentials_meta = input_type(**field_value) + # Write normalized values back so JSON schema validation also passes + # (model_validator may have fixed legacy formats like "ProviderName.MCP") + input_data[field_name] = credentials_meta.model_dump(mode="json") + try: + credentials, lock = await creds_manager.acquire( + user_id, credentials_meta.id + ) + except ValueError: + # Credential was deleted or doesn't exist. + # If the field has a default, run without credentials. 
+ if input_model.model_fields[field_name].default is not None: + log_metadata.warning( + f"Credentials #{credentials_meta.id} not found, " + "running without (field has default)" + ) + input_data[field_name] = None + continue + raise creds_locks.append(lock) extra_exec_kwargs[field_name] = credentials diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index bb5da1e527..2b9a454061 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -260,7 +260,13 @@ async def _validate_node_input_credentials( # Track if any credential field is missing for this node has_missing_credentials = False + # A credential field is optional if the node metadata says so, or if + # the block schema declares a default for the field. + required_fields = block.input_schema.get_required_fields() + is_creds_optional = node.credentials_optional + for field_name, credentials_meta_type in credentials_fields.items(): + field_is_optional = is_creds_optional or field_name not in required_fields try: # Check nodes_input_masks first, then input_default field_value = None @@ -273,7 +279,7 @@ async def _validate_node_input_credentials( elif field_name in node.input_default: # For optional credentials, don't use input_default - treat as missing # This prevents stale credential IDs from failing validation - if node.credentials_optional: + if field_is_optional: field_value = None else: field_value = node.input_default[field_name] @@ -283,8 +289,8 @@ async def _validate_node_input_credentials( isinstance(field_value, dict) and not field_value.get("id") ): has_missing_credentials = True - # If node has credentials_optional flag, mark for skipping instead of error - if node.credentials_optional: + # If credential field is optional, skip instead of error + if field_is_optional: continue # Don't add error, will be marked for skip after loop else: credential_errors[node.id][ @@ -334,16 +340,16 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - # If node has optional credentials and any are missing, mark for skipping - # But only if there are no other errors for this node + # If node has optional credentials and any are missing, allow running without. + # The executor will pass credentials=None to the block's run(). 
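+ # Note: the node is intentionally no longer added to nodes_to_skip;
+ # only the informational log line below is kept, so the node still
+ # executes (with credentials=None).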
if ( has_missing_credentials - and node.credentials_optional + and is_creds_optional and node.id not in credential_errors ): - nodes_to_skip.add(node.id) logger.info( - f"Node #{node.id} will be skipped: optional credentials not configured" + f"Node #{node.id}: optional credentials not configured, " + "running without" ) return credential_errors, nodes_to_skip diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index db33249583..069086a6fd 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph @@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( nodes_input_masks=None, ) - # Node should be in nodes_to_skip, not in errors - assert mock_node.id in nodes_to_skip + # Node should NOT be in nodes_to_skip (runs without credentials) and not in errors + assert mock_node.id not in nodes_to_skip assert mock_node.id not in errors @@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph diff --git a/autogpt_platform/backend/backend/integrations/credentials_store.py b/autogpt_platform/backend/backend/integrations/credentials_store.py index 384405b0c7..3e79a6c047 100644 --- a/autogpt_platform/backend/backend/integrations/credentials_store.py +++ b/autogpt_platform/backend/backend/integrations/credentials_store.py @@ -22,6 +22,27 @@ from backend.util.settings import Settings settings = Settings() + +def provider_matches(stored: str, expected: str) -> bool: + """Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug. + + On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"`` + instead of ``"mcp"``. OAuth states persisted with the buggy format need + to match when ``expected`` is the canonical value (e.g. ``"mcp"``). 
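+
+    Illustrative behaviour (ProviderName.MCP has the value "mcp"):
+
+        provider_matches("mcp", "mcp")               -> True
+        provider_matches("ProviderName.MCP", "mcp")  -> True
+        provider_matches("github", "mcp")            -> False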
+ """ + if stored == expected: + return True + if stored.startswith("ProviderName."): + member = stored.removeprefix("ProviderName.") + from backend.integrations.providers import ProviderName + + try: + return ProviderName[member].value == expected + except KeyError: + pass + return False + + # This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached ollama_credentials = APIKeyCredentials( id="744fdc56-071a-4761-b5a5-0af0ce10a2b5", @@ -389,7 +410,7 @@ class IntegrationCredentialsStore: self, user_id: str, provider: str ) -> list[Credentials]: credentials = await self.get_all_creds(user_id) - return [c for c in credentials if c.provider == provider] + return [c for c in credentials if provider_matches(c.provider, provider)] async def get_authorized_providers(self, user_id: str) -> list[str]: credentials = await self.get_all_creds(user_id) @@ -485,17 +506,6 @@ class IntegrationCredentialsStore: async with self.edit_user_integrations(user_id) as user_integrations: user_integrations.oauth_states.append(state) - async with await self.locked_user_integrations(user_id): - - user_integrations = await self._get_user_integrations(user_id) - oauth_states = user_integrations.oauth_states - oauth_states.append(state) - user_integrations.oauth_states = oauth_states - - await self.db_manager.update_user_integrations( - user_id=user_id, data=user_integrations - ) - return token, code_challenge def _generate_code_challenge(self) -> tuple[str, str]: @@ -521,7 +531,7 @@ class IntegrationCredentialsStore: state for state in oauth_states if secrets.compare_digest(state.token, token) - and state.provider == provider + and provider_matches(state.provider, provider) and state.expires_at > now.timestamp() ), None, diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index f2b6a9da4f..5634dd73b6 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock from backend.data.model import Credentials, OAuth2Credentials from backend.data.redis_client import get_redis_async -from backend.integrations.credentials_store import IntegrationCredentialsStore +from backend.integrations.credentials_store import ( + IntegrationCredentialsStore, + provider_matches, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.util.exceptions import MissingConfigError @@ -137,7 +140,10 @@ class IntegrationCredentialsManager: self, user_id: str, credentials: OAuth2Credentials, lock: bool = True ) -> OAuth2Credentials: async with self._locked(user_id, credentials.id, "refresh"): - oauth_handler = await _get_provider_oauth_handler(credentials.provider) + if provider_matches(credentials.provider, ProviderName.MCP.value): + oauth_handler = create_mcp_oauth_handler(credentials) + else: + oauth_handler = await _get_provider_oauth_handler(credentials.provider) if oauth_handler.needs_refresh(credentials): logger.debug( f"Refreshing '{credentials.provider}' " @@ -236,3 +242,31 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl client_secret=client_secret, redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback", ) + + +def create_mcp_oauth_handler( + credentials: OAuth2Credentials, +) -> 
"BaseOAuthHandler": + """Create an MCPOAuthHandler from credential metadata for token refresh. + + MCP OAuth handlers have dynamic endpoints discovered per-server, so they + can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler + is reconstructed from metadata stored on the credential during initial auth. + """ + from backend.blocks.mcp.oauth import MCPOAuthHandler + + meta = credentials.metadata or {} + token_url = meta.get("mcp_token_url", "") + if not token_url: + raise ValueError( + f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; " + "cannot refresh tokens" + ) + return MCPOAuthHandler( + client_id=meta.get("mcp_client_id", ""), + client_secret=meta.get("mcp_client_secret", ""), + redirect_uri="", # Not needed for token refresh + authorize_url="", # Not needed for token refresh + token_url=token_url, + resource_url=meta.get("mcp_resource_url"), + ) diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index 8a0d6fd183..a462cd787f 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -30,6 +30,7 @@ class ProviderName(str, Enum): IDEOGRAM = "ideogram" JINA = "jina" LLAMA_API = "llama_api" + MCP = "mcp" MEDIUM = "medium" MEM0 = "mem0" NOTION = "notion" diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 99eee404b9..8fdbe10383 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -51,6 +51,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str): if ( creds_meta := new_node.input_default.get(creds_field_name) ) and not await get_credentials(creds_meta["id"]): + # If the credential field is optional (has a default in the + # schema, or node metadata marks it optional), clear the stale + # reference instead of blocking the save. 
+ creds_field_optional = ( + new_node.credentials_optional + or creds_field_name not in block_input_schema.get_required_fields() + ) + if creds_field_optional: + new_node.input_default[creds_field_name] = {} + logger.warning( + f"Node #{new_node.id}: cleared stale optional " + f"credentials #{creds_meta['id']} for " + f"'{creds_field_name}'" + ) + continue raise ValueError( f"Node #{new_node.id} input '{creds_field_name}' updated with " f"non-existent credentials #{creds_meta['id']}" diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index fbd3573112..4eadc41333 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -38,6 +38,7 @@ class Flag(str, Enum): AGENT_ACTIVITY = "agent-activity" ENABLE_PLATFORM_PAYMENT = "enable-platform-payment" CHAT = "chat" + COPILOT_SDK = "copilot-sdk" def is_configured() -> bool: diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 95e5ee32f7..9470909dfc 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver): def __init__(self, ssl_hostname: str, ip_addresses: list[str]): self.ssl_hostname = ssl_hostname self.ip_addresses = ip_addresses - self._default = aiohttp.AsyncResolver() + self._default = aiohttp.ThreadedResolver() async def resolve(self, host, port=0, family=socket.AF_INET): if host == self.ssl_hostname: @@ -467,7 +467,7 @@ class Requests: resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses) ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context) - session_kwargs = {} + session_kwargs: dict = {} if connector: session_kwargs["connector"] = connector diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 48dadb88f1..c5cca87b6e 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -662,6 +662,17 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): mem0_api_key: str = Field(default="", description="Mem0 API key") elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key") + linear_api_key: str = Field( + default="", description="Linear API key for system-level operations" + ) + linear_feature_request_project_id: str = Field( + default="", + description="Linear project ID where feature requests are tracked", + ) + linear_feature_request_team_id: str = Field( + default="", + description="Linear team ID used when creating feature request issues", + ) linear_client_id: str = Field(default="", description="Linear client ID") linear_client_secret: str = Field(default="", description="Linear client secret") diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index d71cca7865..8062457a70 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -897,6 +897,29 @@ files = [ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.35" +description = "Python SDK for Claude Code" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = 
"claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"}, + {file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"}, + {file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"}, +] + +[package.dependencies] +anyio = ">=4.0.0" +mcp = ">=0.1.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] + [[package]] name = "cleo" version = "2.1.0" @@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.3" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, + {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, +] + [[package]] name = "huggingface-hub" version = "1.4.1" @@ -3310,6 +3345,39 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.26.0" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"}, + {file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27.1" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.11.0,<3.0.0" +pydantic-settings = ">=2.5.2" +pyjwt = {version = ">=2.10.1", extras = ["crypto"]} +python-multipart = ">=0.0.9" +pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""} +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +typing-extensions = ">=4.9.0" +typing-inspection = ">=0.4.1" +uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -5994,7 +6062,7 @@ description = "Python for Window Extensions" optional = false python-versions = "*" groups = ["main"] -markers = "platform_system == \"Windows\"" +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, @@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] 
sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "3.2.0" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"}, + {file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"}, +] + +[package.dependencies] +anyio = ">=4.7.0" +starlette = ">=0.49.1" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stagehand" version = "0.5.9" @@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f" +content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 32dfc547bc..7a112e75ca 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -16,6 +16,7 @@ anthropic = "^0.79.0" apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } +claude-agent-sdk = "^0.1.0" click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2" diff --git a/autogpt_platform/backend/test/blocks/test_jina_extract_website.py b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py new file mode 100644 index 0000000000..335c43f966 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py @@ -0,0 +1,66 @@ +from typing import cast + +import pytest + +from backend.blocks.jina._auth import ( + TEST_CREDENTIALS, + TEST_CREDENTIALS_INPUT, + JinaCredentialsInput, +) +from backend.blocks.jina.search import ExtractWebsiteContentBlock +from backend.util.request import HTTPClientError + + +@pytest.mark.asyncio +async def test_extract_website_content_returns_content(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=True, + ) + + async def fake_get_request(url, json=False, headers=None): + assert url == "https://example.com" + assert headers == {} + return "page content" + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page content") in results + assert all(key != "error" for key, _ in results) + + +@pytest.mark.asyncio +async def test_extract_website_content_handles_http_error(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=False, + ) + + async def fake_get_request(url, json=False, headers=None): + raise HTTPClientError("HTTP 400 Error: Bad Request", 400) + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page 
content") not in results + error_messages = [value for key, value in results if key == "error"] + assert error_messages + assert "Client error (400)" in error_messages[0] + assert "https://example.com" in error_messages[0] diff --git a/autogpt_platform/backend/test/blocks/test_list_concatenation.py b/autogpt_platform/backend/test/blocks/test_list_concatenation.py new file mode 100644 index 0000000000..8cea3b60f7 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_list_concatenation.py @@ -0,0 +1,1276 @@ +""" +Comprehensive test suite for list concatenation and manipulation blocks. + +Tests cover: +- ConcatenateListsBlock: basic concatenation, deduplication, None removal +- FlattenListBlock: nested list flattening with depth control +- InterleaveListsBlock: round-robin interleaving of multiple lists +- ZipListsBlock: zipping lists with truncation and padding +- ListDifferenceBlock: computing list differences (regular and symmetric) +- ListIntersectionBlock: finding common elements between lists +- Helper utility functions: validation, flattening, deduplication, etc. +""" + +import pytest + +from backend.blocks.data_manipulation import ( + _MAX_FLATTEN_DEPTH, + ConcatenateListsBlock, + FlattenListBlock, + InterleaveListsBlock, + ListDifferenceBlock, + ListIntersectionBlock, + ZipListsBlock, + _compute_nesting_depth, + _concatenate_lists_simple, + _deduplicate_list, + _filter_none_values, + _flatten_nested_list, + _interleave_lists, + _make_hashable, + _validate_all_lists, + _validate_list_input, +) +from backend.util.test import execute_block_test + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestValidateListInput: + """Tests for the _validate_list_input helper.""" + + def test_valid_list_returns_none(self): + assert _validate_list_input([1, 2, 3], 0) is None + + def test_empty_list_returns_none(self): + assert _validate_list_input([], 0) is None + + def test_none_returns_none(self): + assert _validate_list_input(None, 0) is None + + def test_string_returns_error(self): + result = _validate_list_input("hello", 0) + assert result is not None + assert "str" in result + assert "index 0" in result + + def test_integer_returns_error(self): + result = _validate_list_input(42, 1) + assert result is not None + assert "int" in result + assert "index 1" in result + + def test_dict_returns_error(self): + result = _validate_list_input({"a": 1}, 2) + assert result is not None + assert "dict" in result + assert "index 2" in result + + def test_tuple_returns_error(self): + result = _validate_list_input((1, 2), 3) + assert result is not None + assert "tuple" in result + + def test_boolean_returns_error(self): + result = _validate_list_input(True, 0) + assert result is not None + assert "bool" in result + + def test_float_returns_error(self): + result = _validate_list_input(3.14, 0) + assert result is not None + assert "float" in result + + +class TestValidateAllLists: + """Tests for the _validate_all_lists helper.""" + + def test_all_valid_lists(self): + assert _validate_all_lists([[1], [2], [3]]) is None + + def test_empty_outer_list(self): + assert _validate_all_lists([]) is None + + def test_mixed_valid_and_none(self): + # None is skipped, so this should pass + assert _validate_all_lists([[1], None, [3]]) is None + + def test_invalid_item_returns_error(self): + result = _validate_all_lists([[1], "bad", [3]]) + assert result is not None + 
assert "index 1" in result + + def test_first_invalid_is_returned(self): + result = _validate_all_lists(["first_bad", "second_bad"]) + assert result is not None + assert "index 0" in result + + def test_all_none_passes(self): + assert _validate_all_lists([None, None, None]) is None + + +class TestConcatenateListsSimple: + """Tests for the _concatenate_lists_simple helper.""" + + def test_basic_concatenation(self): + assert _concatenate_lists_simple([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_empty_lists(self): + assert _concatenate_lists_simple([[], []]) == [] + + def test_single_list(self): + assert _concatenate_lists_simple([[1, 2, 3]]) == [1, 2, 3] + + def test_no_lists(self): + assert _concatenate_lists_simple([]) == [] + + def test_skip_none_values(self): + assert _concatenate_lists_simple([[1, 2], None, [3, 4]]) == [1, 2, 3, 4] # type: ignore[arg-type] + + def test_mixed_types(self): + result = _concatenate_lists_simple([[1, "a"], [True, 3.14]]) + assert result == [1, "a", True, 3.14] + + def test_nested_lists_preserved(self): + result = _concatenate_lists_simple([[[1, 2]], [[3, 4]]]) + assert result == [[1, 2], [3, 4]] + + def test_large_number_of_lists(self): + lists = [[i] for i in range(100)] + result = _concatenate_lists_simple(lists) + assert result == list(range(100)) + + +class TestFlattenNestedList: + """Tests for the _flatten_nested_list helper.""" + + def test_already_flat(self): + assert _flatten_nested_list([1, 2, 3]) == [1, 2, 3] + + def test_one_level_nesting(self): + assert _flatten_nested_list([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_deep_nesting(self): + assert _flatten_nested_list([1, [2, [3, [4, [5]]]]]) == [1, 2, 3, 4, 5] + + def test_empty_list(self): + assert _flatten_nested_list([]) == [] + + def test_mixed_nesting(self): + assert _flatten_nested_list([1, [2, 3], 4, [5, [6]]]) == [1, 2, 3, 4, 5, 6] + + def test_max_depth_zero(self): + # max_depth=0 means no flattening at all + result = _flatten_nested_list([[1, 2], [3, 4]], max_depth=0) + assert result == [[1, 2], [3, 4]] + + def test_max_depth_one(self): + result = _flatten_nested_list([[1, [2, 3]], [4, [5]]], max_depth=1) + assert result == [1, [2, 3], 4, [5]] + + def test_max_depth_two(self): + result = _flatten_nested_list([[[1, 2], [3]], [[4, [5]]]], max_depth=2) + assert result == [1, 2, 3, 4, [5]] + + def test_unlimited_depth(self): + deeply_nested = [[[[[[[1]]]]]]] + assert _flatten_nested_list(deeply_nested, max_depth=-1) == [1] + + def test_preserves_non_list_iterables(self): + result = _flatten_nested_list(["hello", [1, 2]]) + assert result == ["hello", 1, 2] + + def test_preserves_dicts(self): + result = _flatten_nested_list([{"a": 1}, [{"b": 2}]]) + assert result == [{"a": 1}, {"b": 2}] + + def test_excessive_depth_raises_recursion_error(self): + """Deeply nested lists beyond 1000 levels should raise RecursionError.""" + # Build a list nested 1100 levels deep + nested = [42] + for _ in range(1100): + nested = [nested] + with pytest.raises(RecursionError, match="maximum.*depth"): + _flatten_nested_list(nested, max_depth=-1) + + +class TestDeduplicateList: + """Tests for the _deduplicate_list helper.""" + + def test_no_duplicates(self): + assert _deduplicate_list([1, 2, 3]) == [1, 2, 3] + + def test_with_duplicates(self): + assert _deduplicate_list([1, 2, 2, 3, 3, 3]) == [1, 2, 3] + + def test_all_duplicates(self): + assert _deduplicate_list([1, 1, 1]) == [1] + + def test_empty_list(self): + assert _deduplicate_list([]) == [] + + def test_preserves_order(self): + result = 
_deduplicate_list([3, 1, 2, 1, 3]) + assert result == [3, 1, 2] + + def test_string_duplicates(self): + assert _deduplicate_list(["a", "b", "a", "c"]) == ["a", "b", "c"] + + def test_mixed_types(self): + result = _deduplicate_list([1, "1", 1, "1"]) + assert result == [1, "1"] + + def test_dict_duplicates(self): + result = _deduplicate_list([{"a": 1}, {"a": 1}, {"b": 2}]) + assert result == [{"a": 1}, {"b": 2}] + + def test_list_duplicates(self): + result = _deduplicate_list([[1, 2], [1, 2], [3, 4]]) + assert result == [[1, 2], [3, 4]] + + def test_none_duplicates(self): + result = _deduplicate_list([None, 1, None, 2]) + assert result == [None, 1, 2] + + def test_single_element(self): + assert _deduplicate_list([42]) == [42] + + +class TestMakeHashable: + """Tests for the _make_hashable helper.""" + + def test_integer(self): + assert _make_hashable(42) == 42 + + def test_string(self): + assert _make_hashable("hello") == "hello" + + def test_none(self): + assert _make_hashable(None) is None + + def test_dict_returns_tuple(self): + result = _make_hashable({"a": 1}) + assert isinstance(result, tuple) + # Should be hashable + hash(result) + + def test_list_returns_tuple(self): + result = _make_hashable([1, 2, 3]) + assert result == (1, 2, 3) + + def test_same_dict_same_hash(self): + assert _make_hashable({"a": 1, "b": 2}) == _make_hashable({"a": 1, "b": 2}) + + def test_different_dict_different_hash(self): + assert _make_hashable({"a": 1}) != _make_hashable({"a": 2}) + + def test_dict_key_order_independent(self): + """Dicts with same keys in different insertion order produce same result.""" + d1 = {"b": 2, "a": 1} + d2 = {"a": 1, "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + def test_tuple_hashable(self): + result = _make_hashable((1, 2, 3)) + assert result == (1, 2, 3) + hash(result) + + def test_boolean(self): + result = _make_hashable(True) + assert result is True + + def test_float(self): + result = _make_hashable(3.14) + assert result == 3.14 + + +class TestFilterNoneValues: + """Tests for the _filter_none_values helper.""" + + def test_removes_none(self): + assert _filter_none_values([1, None, 2, None, 3]) == [1, 2, 3] + + def test_no_none(self): + assert _filter_none_values([1, 2, 3]) == [1, 2, 3] + + def test_all_none(self): + assert _filter_none_values([None, None, None]) == [] + + def test_empty_list(self): + assert _filter_none_values([]) == [] + + def test_preserves_falsy_values(self): + assert _filter_none_values([0, False, "", None, []]) == [0, False, "", []] + + +class TestComputeNestingDepth: + """Tests for the _compute_nesting_depth helper.""" + + def test_flat_list(self): + assert _compute_nesting_depth([1, 2, 3]) == 1 + + def test_one_level(self): + assert _compute_nesting_depth([[1, 2], [3, 4]]) == 2 + + def test_deep_nesting(self): + assert _compute_nesting_depth([[[[]]]]) == 4 + + def test_mixed_depth(self): + depth = _compute_nesting_depth([1, [2, [3]]]) + assert depth == 3 + + def test_empty_list(self): + assert _compute_nesting_depth([]) == 1 + + def test_non_list(self): + assert _compute_nesting_depth(42) == 0 + + def test_string_not_recursed(self): + # Strings should not be treated as nested lists + assert _compute_nesting_depth(["hello"]) == 1 + + +class TestInterleaveListsHelper: + """Tests for the _interleave_lists helper.""" + + def test_equal_length_lists(self): + result = _interleave_lists([[1, 2, 3], ["a", "b", "c"]]) + assert result == [1, "a", 2, "b", 3, "c"] + + def test_unequal_length_lists(self): + result = _interleave_lists([[1, 2, 3], 
["a"]]) + assert result == [1, "a", 2, 3] + + def test_empty_input(self): + assert _interleave_lists([]) == [] + + def test_single_list(self): + assert _interleave_lists([[1, 2, 3]]) == [1, 2, 3] + + def test_three_lists(self): + result = _interleave_lists([[1], [2], [3]]) + assert result == [1, 2, 3] + + def test_with_none_list(self): + result = _interleave_lists([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [1, 3, 2, 4] + + def test_all_empty_lists(self): + assert _interleave_lists([[], [], []]) == [] + + def test_all_none_lists(self): + """All-None inputs should return empty list, not crash.""" + assert _interleave_lists([None, None, None]) == [] # type: ignore[arg-type] + + +class TestComputeNestingDepthEdgeCases: + """Tests for _compute_nesting_depth with deeply nested input.""" + + def test_deeply_nested_does_not_crash(self): + """Deeply nested lists beyond 1000 levels should not raise RecursionError.""" + nested = [42] + for _ in range(1100): + nested = [nested] + # Should return a depth value without crashing + depth = _compute_nesting_depth(nested) + assert depth >= _MAX_FLATTEN_DEPTH + + +class TestMakeHashableMixedKeys: + """Tests for _make_hashable with mixed-type dict keys.""" + + def test_mixed_type_dict_keys(self): + """Dicts with mixed-type keys (int and str) should not crash sorted().""" + d = {1: "one", "two": 2} + result = _make_hashable(d) + assert isinstance(result, tuple) + hash(result) # Should be hashable without error + + def test_mixed_type_keys_deterministic(self): + """Same dict with mixed keys produces same result.""" + d1 = {1: "a", "b": 2} + d2 = {1: "a", "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + +class TestZipListsNoneHandling: + """Tests for ZipListsBlock with None values in input.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate_with_none(self): + """_zip_truncate should handle None values in input lists.""" + result = self.block._zip_truncate([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [[1, 3], [2, 4]] + + def test_zip_pad_with_none(self): + """_zip_pad should handle None values in input lists.""" + result = self.block._zip_pad([[1, 2, 3], None, ["a"]], fill_value="X") # type: ignore[arg-type] + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_truncate_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_truncate([None, None]) # type: ignore[arg-type] + assert result == [] + + def test_zip_pad_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_pad([None, None], fill_value=0) # type: ignore[arg-type] + assert result == [] + + +# ============================================================================= +# Block Built-in Tests (using test_input/test_output) +# ============================================================================= + + +class TestConcatenateListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ConcatenateListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ConcatenateListsBlock() + await execute_block_test(block) + + +class TestFlattenListBlockBuiltin: + """Run the built-in test_input/test_output tests for FlattenListBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = FlattenListBlock() + await execute_block_test(block) + + +class TestInterleaveListsBlockBuiltin: + """Run the built-in test_input/test_output tests for InterleaveListsBlock.""" + + 
@pytest.mark.asyncio + async def test_builtin_tests(self): + block = InterleaveListsBlock() + await execute_block_test(block) + + +class TestZipListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ZipListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ZipListsBlock() + await execute_block_test(block) + + +class TestListDifferenceBlockBuiltin: + """Run the built-in test_input/test_output tests for ListDifferenceBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListDifferenceBlock() + await execute_block_test(block) + + +class TestListIntersectionBlockBuiltin: + """Run the built-in test_input/test_output tests for ListIntersectionBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListIntersectionBlock() + await execute_block_test(block) + + +# ============================================================================= +# ConcatenateListsBlock Manual Tests +# ============================================================================= + + +class TestConcatenateListsBlockManual: + """Manual test cases for ConcatenateListsBlock edge cases.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + @pytest.mark.asyncio + async def test_two_lists(self): + """Test basic two-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_three_lists(self): + """Test three-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_five_lists(self): + """Test concatenation of five lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3], [4], [5]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + assert results["length"] == 5 + + @pytest.mark.asyncio + async def test_empty_lists_only(self): + """Test concatenation of only empty lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[], [], []]) + ): + results[name] = value + assert results["concatenated_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_mixed_types_in_lists(self): + """Test concatenation with mixed types.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, "a"], [True, 3.14], [None, {"key": "val"}]] + ) + ): + results[name] = value + assert results["concatenated_list"] == [ + 1, + "a", + True, + 3.14, + None, + {"key": "val"}, + ] + + @pytest.mark.asyncio + async def test_deduplication_enabled(self): + """Test deduplication removes duplicates.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, 2, 3], [2, 3, 4], [3, 4, 5]], + deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_deduplication_preserves_order(self): + """Test that deduplication preserves first-occurrence order.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[3, 1, 2], [2, 4, 1]], 
+ deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [3, 1, 2, 4] + + @pytest.mark.asyncio + async def test_remove_none_enabled(self): + """Test None removal from concatenated results.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None], [None, 2], [3, None]], + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_dedup_and_remove_none_combined(self): + """Test both deduplication and None removal together.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None, 2], [2, None, 3]], + deduplicate=True, + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_nested_lists_preserved(self): + """Test that nested lists are not flattened during concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[[1, 2]], [[3, 4]]]) + ): + results[name] = value + assert results["concatenated_list"] == [[1, 2], [3, 4]] + + @pytest.mark.asyncio + async def test_large_lists(self): + """Test concatenation of large lists.""" + list_a = list(range(1000)) + list_b = list(range(1000, 2000)) + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[list_a, list_b]) + ): + results[name] = value + assert results["concatenated_list"] == list(range(2000)) + assert results["length"] == 2000 + + @pytest.mark.asyncio + async def test_single_list_input(self): + """Test concatenation with a single list.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2, 3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + @pytest.mark.asyncio + async def test_block_category(self): + """Test that the block has the correct category.""" + from backend.blocks._base import BlockCategory + + assert BlockCategory.BASIC in self.block.categories + + +# ============================================================================= +# FlattenListBlock Manual Tests +# ============================================================================= + + +class TestFlattenListBlockManual: + """Manual test cases for FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + @pytest.mark.asyncio + async def test_simple_flatten(self): + """Test flattening a simple nested list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_deeply_nested(self): + """Test flattening a deeply nested structure.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3, [4, [5]]]]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_partial_flatten(self): + """Test flattening with max_depth=1.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input( + nested_list=[[1, [2, 3]], [4, 
[5]]], + max_depth=1, + ) + ): + results[name] = value + assert results["flattened_list"] == [1, [2, 3], 4, [5]] + + @pytest.mark.asyncio + async def test_already_flat_list(self): + """Test flattening an already flat list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, 2, 3, 4]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_empty_nested_lists(self): + """Test flattening with empty nested lists.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[[], [1], [], [2], []]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2] + + @pytest.mark.asyncio + async def test_mixed_types_preserved(self): + """Test that non-list types are preserved during flattening.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=["hello", [1, {"a": 1}], [True]]) + ): + results[name] = value + assert results["flattened_list"] == ["hello", 1, {"a": 1}, True] + + @pytest.mark.asyncio + async def test_original_depth_reported(self): + """Test that original nesting depth is correctly reported.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3]]]) + ): + results[name] = value + assert results["original_depth"] == 3 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# InterleaveListsBlock Manual Tests +# ============================================================================= + + +class TestInterleaveListsBlockManual: + """Manual test cases for InterleaveListsBlock.""" + + def setup_method(self): + self.block = InterleaveListsBlock() + + @pytest.mark.asyncio + async def test_equal_length_interleave(self): + """Test interleaving two equal-length lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, "c"] + + @pytest.mark.asyncio + async def test_unequal_length_interleave(self): + """Test interleaving lists of different lengths.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3, 4], ["a", "b"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, 4] + + @pytest.mark.asyncio + async def test_three_lists_interleave(self): + """Test interleaving three lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2], ["a", "b"], ["x", "y"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", "x", 2, "b", "y"] + + @pytest.mark.asyncio + async def test_single_element_lists(self): + """Test interleaving single-element lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1], [2], [3], [4]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# 
============================================================================= +# ZipListsBlock Manual Tests +# ============================================================================= + + +class TestZipListsBlockManual: + """Manual test cases for ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + @pytest.mark.asyncio + async def test_basic_zip(self): + """Test basic zipping of two lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"], [3, "c"]] + + @pytest.mark.asyncio + async def test_truncate_to_shortest(self): + """Test that default behavior truncates to shortest list.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"]] + assert results["length"] == 2 + + @pytest.mark.asyncio + async def test_pad_to_longest(self): + """Test padding shorter lists with fill value.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2, 3], ["a"]], + pad_to_longest=True, + fill_value="X", + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "X"], [3, "X"]] + + @pytest.mark.asyncio + async def test_pad_with_none(self): + """Test padding with None (default fill value).""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2], ["a"]], + pad_to_longest=True, + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, None]] + + @pytest.mark.asyncio + async def test_three_lists_zip(self): + """Test zipping three lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2], ["a", "b"], [True, False]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a", True], [2, "b", False]] + + @pytest.mark.asyncio + async def test_empty_lists_zip(self): + """Test zipping empty input.""" + results = {} + async for name, value in self.block.run(ZipListsBlock.Input(lists=[])): + results[name] = value + assert results["zipped_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListDifferenceBlock Manual Tests +# ============================================================================= + + +class TestListDifferenceBlockManual: + """Manual test cases for ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + @pytest.mark.asyncio + async def test_basic_difference(self): + """Test basic set difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["difference"] == [1, 2] + + @pytest.mark.asyncio + async def test_symmetric_difference(self): + """Test symmetric difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[2, 3, 4], + symmetric=True, + ) + ): + results[name] = value + assert results["difference"] == [1, 4] + + @pytest.mark.asyncio + async def 
test_no_difference(self): + """Test when lists are identical.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["difference"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_complete_difference(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["difference"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_string_difference(self): + """Test difference with string elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "date"], + ) + ): + results[name] = value + assert results["difference"] == ["apple", "cherry"] + + @pytest.mark.asyncio + async def test_dict_difference(self): + """Test difference with dictionary elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[{"a": 1}, {"b": 2}, {"c": 3}], + list_b=[{"b": 2}], + ) + ): + results[name] = value + assert results["difference"] == [{"a": 1}, {"c": 3}] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListIntersectionBlock Manual Tests +# ============================================================================= + + +class TestListIntersectionBlockManual: + """Manual test cases for ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + @pytest.mark.asyncio + async def test_basic_intersection(self): + """Test basic intersection.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["intersection"] == [3, 4, 5] + assert results["length"] == 3 + + @pytest.mark.asyncio + async def test_no_intersection(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["intersection"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_identical_lists(self): + """Test intersection of identical lists.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_preserves_order_from_list_a(self): + """Test that 
intersection preserves order from list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[5, 3, 1], + list_b=[1, 3, 5], + ) + ): + results[name] = value + assert results["intersection"] == [5, 3, 1] + + @pytest.mark.asyncio + async def test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_string_intersection(self): + """Test intersection with string elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "cherry", "date"], + ) + ): + results[name] = value + assert results["intersection"] == ["banana", "cherry"] + + @pytest.mark.asyncio + async def test_deduplication_in_intersection(self): + """Test that duplicates in input don't cause duplicate results.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 1, 2, 2, 3], + list_b=[1, 2], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# Block Method Tests +# ============================================================================= + + +class TestConcatenateListsBlockMethods: + """Tests for internal methods of ConcatenateListsBlock.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + def test_validate_inputs_valid(self): + assert self.block._validate_inputs([[1], [2]]) is None + + def test_validate_inputs_invalid(self): + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + def test_perform_concatenation(self): + result = self.block._perform_concatenation([[1, 2], [3, 4]]) + assert result == [1, 2, 3, 4] + + def test_apply_deduplication(self): + result = self.block._apply_deduplication([1, 2, 2, 3]) + assert result == [1, 2, 3] + + def test_apply_none_removal(self): + result = self.block._apply_none_removal([1, None, 2]) + assert result == [1, 2] + + def test_post_process_all_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=True, remove_none=True + ) + assert result == [1, 2] + + def test_post_process_no_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=False, remove_none=False + ) + assert result == [1, None, 2, None, 2] + + +class TestFlattenListBlockMethods: + """Tests for internal methods of FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + def test_compute_depth_flat(self): + assert self.block._compute_depth([1, 2, 3]) == 1 + + def test_compute_depth_nested(self): + assert self.block._compute_depth([[1, [2]]]) == 3 + + def test_flatten_unlimited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=-1) + assert result == [1, 2, 3] + + def 
test_flatten_limited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=1) + assert result == [1, 2, [3]] + + def test_validate_max_depth_valid(self): + assert self.block._validate_max_depth(-1) is None + assert self.block._validate_max_depth(0) is None + assert self.block._validate_max_depth(5) is None + + def test_validate_max_depth_invalid(self): + result = self.block._validate_max_depth(-2) + assert result is not None + + +class TestZipListsBlockMethods: + """Tests for internal methods of ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate(self): + result = self.block._zip_truncate([[1, 2, 3], ["a", "b"]]) + assert result == [[1, "a"], [2, "b"]] + + def test_zip_pad(self): + result = self.block._zip_pad([[1, 2, 3], ["a"]], fill_value="X") + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_pad_empty(self): + result = self.block._zip_pad([], fill_value=None) + assert result == [] + + def test_validate_inputs(self): + assert self.block._validate_inputs([[1], [2]]) is None + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + +class TestListDifferenceBlockMethods: + """Tests for internal methods of ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + def test_compute_difference(self): + result = self.block._compute_difference([1, 2, 3], [2, 3, 4]) + assert result == [1] + + def test_compute_symmetric_difference(self): + result = self.block._compute_symmetric_difference([1, 2, 3], [2, 3, 4]) + assert result == [1, 4] + + def test_compute_difference_empty(self): + result = self.block._compute_difference([], [1, 2]) + assert result == [] + + def test_compute_symmetric_difference_identical(self): + result = self.block._compute_symmetric_difference([1, 2], [1, 2]) + assert result == [] + + +class TestListIntersectionBlockMethods: + """Tests for internal methods of ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + def test_compute_intersection(self): + result = self.block._compute_intersection([1, 2, 3], [2, 3, 4]) + assert result == [2, 3] + + def test_compute_intersection_empty(self): + result = self.block._compute_intersection([], [1, 2]) + assert result == [] + + def test_compute_intersection_no_overlap(self): + result = self.block._compute_intersection([1, 2], [3, 4]) + assert result == [] diff --git a/autogpt_platform/backend/test/chat/__init__.py b/autogpt_platform/backend/test/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/test/chat/test_security_hooks.py b/autogpt_platform/backend/test/chat/test_security_hooks.py new file mode 100644 index 0000000000..f10a90871b --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_security_hooks.py @@ -0,0 +1,133 @@ +"""Tests for SDK security hooks — workspace paths, tool access, and deny messages. + +These are pure unit tests with no external dependencies (no SDK, no DB, no server). +They validate that the security hooks correctly block unauthorized paths, +tool access, and dangerous input patterns. + +Note: Bash command validation was removed — the SDK built-in Bash tool is not in +allowed_tools, and the bash_exec MCP tool has kernel-level network isolation +(unshare --net) making command-level parsing unnecessary. 
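# A minimal sketch, assumed from the assertions in this module rather than copied
# from the hooks themselves, of the result shape the helpers below inspect: a
# denial carries a "hookSpecificOutput" mapping whose permissionDecision is
# "deny" and whose reason contains "[SECURITY]" and "cannot be bypassed"; an
# allow result simply omits that decision. The literal reason text is illustrative.
EXAMPLE_DENY_RESULT = {
    "hookSpecificOutput": {
        "permissionDecision": "deny",
        "permissionDecisionReason": (
            "[SECURITY] Tool 'bash' is blocked in this environment. "
            "This restriction cannot be bypassed."
        ),
    }
}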
+""" + +from backend.api.features.chat.sdk.security_hooks import ( + _validate_tool_access, + _validate_workspace_path, +) + +SDK_CWD = "/tmp/copilot-test-session" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +def _reason(result: dict) -> str: + return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") + + +# ============================================================ +# Workspace path validation (Read, Write, Edit, etc.) +# ============================================================ + + +class TestWorkspacePathValidation: + def test_path_in_workspace(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_path_outside_workspace(self): + result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + assert _is_denied(result) + + def test_tool_results_allowed(self): + result = _validate_workspace_path( + "Read", + {"file_path": "~/.claude/projects/abc/tool-results/out.txt"}, + SDK_CWD, + ) + assert not _is_denied(result) + + def test_claude_settings_blocked(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_claude_projects_without_tool_results(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_no_path_allowed(self): + """Glob/Grep without path defaults to cwd — should be allowed.""" + result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD) + assert not _is_denied(result) + + def test_path_traversal_with_dotdot(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD + ) + assert _is_denied(result) + + +# ============================================================ +# Tool access validation +# ============================================================ + + +class TestToolAccessValidation: + def test_blocked_tools(self): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"Tool '{tool}' should be blocked" + + def test_bash_builtin_blocked(self): + """SDK built-in Bash (capital) is blocked as defence-in-depth.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD) + assert _is_denied(result) + assert "Bash" in _reason(result) + + def test_workspace_tools_delegate(self): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_dangerous_pattern_blocked(self): + result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"}) + assert _is_denied(result) + + def test_safe_unknown_tool_allowed(self): + result = _validate_tool_access("SomeSafeTool", {"data": "hello world"}) + assert not _is_denied(result) + + +# ============================================================ +# Deny message quality (ntindle feedback) +# ============================================================ + + +class TestDenyMessageClarity: + """Deny messages must include [SECURITY] and 'cannot be bypassed' + so the model knows the restriction is enforced, not a suggestion.""" + + def test_blocked_tool_message(self): + reason = _reason(_validate_tool_access("bash", {})) + assert "[SECURITY]" in reason + assert "cannot be 
bypassed" in reason + + def test_bash_builtin_blocked_message(self): + reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + def test_workspace_path_message(self): + reason = _reason( + _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + ) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason diff --git a/autogpt_platform/backend/test/chat/test_transcript.py b/autogpt_platform/backend/test/chat/test_transcript.py new file mode 100644 index 0000000000..71b1fad81f --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_transcript.py @@ -0,0 +1,255 @@ +"""Unit tests for JSONL transcript management utilities.""" + +import json +import os + +from backend.api.features.chat.sdk.transcript import ( + STRIPPABLE_TYPES, + read_transcript_file, + strip_progress_entries, + validate_transcript, + write_transcript_to_tempfile, +) + + +def _make_jsonl(*entries: dict) -> str: + return "\n".join(json.dumps(e) for e in entries) + "\n" + + +# --- Fixtures --- + + +METADATA_LINE = {"type": "queue-operation", "subtype": "create"} +FILE_HISTORY = {"type": "file-history-snapshot", "files": []} +USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}} +ASST_MSG = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": {"role": "assistant", "content": "hello"}, +} +PROGRESS_ENTRY = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress", "stdout": "running..."}, +} + +VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG) + + +# --- read_transcript_file --- + + +class TestReadTranscriptFile: + def test_returns_content_for_valid_file(self, tmp_path): + path = tmp_path / "session.jsonl" + path.write_text(VALID_TRANSCRIPT) + result = read_transcript_file(str(path)) + assert result is not None + assert "user" in result + + def test_returns_none_for_missing_file(self): + assert read_transcript_file("/nonexistent/path.jsonl") is None + + def test_returns_none_for_empty_path(self): + assert read_transcript_file("") is None + + def test_returns_none_for_empty_file(self, tmp_path): + path = tmp_path / "empty.jsonl" + path.write_text("") + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_metadata_only(self, tmp_path): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + path = tmp_path / "meta.jsonl" + path.write_text(content) + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_invalid_json(self, tmp_path): + path = tmp_path / "bad.jsonl" + path.write_text("not json\n{}\n{}\n") + assert read_transcript_file(str(path)) is None + + def test_no_size_limit(self, tmp_path): + """Large files are accepted — bucket storage has no size limit.""" + big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000} + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG) + path = tmp_path / "big.jsonl" + path.write_text(content) + result = read_transcript_file(str(path)) + assert result is not None + + +# --- write_transcript_to_tempfile --- + + +class TestWriteTranscriptToTempfile: + """Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check.""" + + def test_writes_file_and_returns_path(self): + cwd = "/tmp/copilot-test-write" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "sess-1234-abcd", cwd + ) + assert result is not None + assert os.path.isfile(result) + assert 
result.endswith(".jsonl") + with open(result) as f: + assert f.read() == VALID_TRANSCRIPT + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_creates_parent_directory(self): + cwd = "/tmp/copilot-test-mkdir" + try: + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is not None + assert os.path.isdir(cwd) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_uses_session_id_prefix(self): + cwd = "/tmp/copilot-test-prefix" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "abcdef12-rest", cwd + ) + assert result is not None + assert "abcdef12" in os.path.basename(result) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_rejects_cwd_outside_sandbox(self, tmp_path): + cwd = str(tmp_path / "not-copilot") + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is None + + +# --- validate_transcript --- + + +class TestValidateTranscript: + def test_valid_transcript(self): + assert validate_transcript(VALID_TRANSCRIPT) is True + + def test_none_content(self): + assert validate_transcript(None) is False + + def test_empty_content(self): + assert validate_transcript("") is False + + def test_metadata_only(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + assert validate_transcript(content) is False + + def test_user_only_no_assistant(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG) + assert validate_transcript(content) is False + + def test_assistant_only_no_user(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG) + assert validate_transcript(content) is False + + def test_invalid_json_returns_false(self): + assert validate_transcript("not json\n{}\n{}\n") is False + + +# --- strip_progress_entries --- + + +class TestStripProgressEntries: + def test_strips_all_strippable_types(self): + """All STRIPPABLE_TYPES are removed from the output.""" + entries = [ + USER_MSG, + {"type": "progress", "uuid": "p1", "parentUuid": "u1"}, + {"type": "file-history-snapshot", "files": []}, + {"type": "queue-operation", "subtype": "create"}, + {"type": "summary", "text": "..."}, + {"type": "pr-link", "url": "..."}, + ASST_MSG, + ] + result = strip_progress_entries(_make_jsonl(*entries)) + result_types = {json.loads(line)["type"] for line in result.strip().split("\n")} + assert result_types == {"user", "assistant"} + for stype in STRIPPABLE_TYPES: + assert stype not in result_types + + def test_reparents_children_of_stripped_entries(self): + """An assistant message whose parent is a progress entry gets reparented.""" + progress = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress"}, + } + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p1", # Points to progress + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, progress, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + # Should be reparented to u1 (the user message) + assert asst_entry["parentUuid"] == "u1" + + def test_reparents_through_chain(self): + """Reparenting walks through multiple stripped entries.""" + p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"} + p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"} + p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"} + 
asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p3", # 3 levels deep + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, p1, p2, p3, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + assert asst_entry["parentUuid"] == "u1" + + def test_preserves_non_strippable_entries(self): + """User, assistant, and system entries are preserved.""" + system = {"type": "system", "uuid": "s1", "message": "prompt"} + content = _make_jsonl(system, USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_types = [json.loads(line)["type"] for line in result.strip().split("\n")] + assert result_types == ["system", "user", "assistant"] + + def test_empty_input(self): + result = strip_progress_entries("") + # Should return just a newline (empty content stripped) + assert result.strip() == "" + + def test_no_strippable_entries(self): + """When there's nothing to strip, output matches input structure.""" + content = _make_jsonl(USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_lines = result.strip().split("\n") + assert len(result_lines) == 2 + + def test_handles_entries_without_uuid(self): + """Entries without uuid field are handled gracefully.""" + no_uuid = {"type": "queue-operation", "subtype": "create"} + content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_types = [json.loads(line)["type"] for line in result.strip().split("\n")] + # queue-operation is strippable + assert "queue-operation" not in result_types + assert "user" in result_types + assert "assistant" in result_types diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts new file mode 100644 index 0000000000..326f42e049 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts @@ -0,0 +1,96 @@ +import { NextResponse } from "next/server"; + +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(//g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + +// MCP-specific OAuth callback route. +// +// Unlike the generic oauth_callback which relies on window.opener.postMessage, +// this route uses BroadcastChannel as the PRIMARY communication method. +// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost) +// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers. +// +// BroadcastChannel works across all same-origin tabs/popups regardless of opener. +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const code = searchParams.get("code"); + const state = searchParams.get("state"); + + const success = Boolean(code && state); + const message = success + ? { success: true, code, state } + : { + success: false, + message: `Missing parameters: ${searchParams.toString()}`, + }; + + return new NextResponse( + ` + + MCP Sign-in + +
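// A minimal sketch of the conventional script-context escaping the helper above
// describes (an assumed restatement, not code copied from this PR): "<", ">",
// and "&" are replaced with \u escapes so the embedded JSON cannot terminate the
// surrounding <script> tag or inject new markup.
function escapeJsonForScriptTag(value: unknown): string {
  return JSON.stringify(value)
    .replace(/</g, "\\u003c")
    .replace(/>/g, "\\u003e")
    .replace(/&/g, "\\u0026");
}
// escapeJsonForScriptTag({ note: "</script>" }) -> '{"note":"\u003c/script\u003e"}'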
+
+

Completing sign-in...

+
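// A minimal sketch (not this route's actual inline script) of how the callback
// page can hand the embedded payload back to the tab that opened it. The channel
// name "mcp_oauth_callback" and the payload values are assumptions for
// illustration; BroadcastChannel reaches same-origin listeners even when COOP has
// severed window.opener, which is why postMessage is only a best-effort fallback.
const message = { success: true, code: "auth-code", state: "state-token" };
const channel = new BroadcastChannel("mcp_oauth_callback");
channel.postMessage(message);
channel.close();
if (window.opener) {
  window.opener.postMessage(message, window.location.origin);
}
window.close();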
+ + + +`, + { headers: { "Content-Type": "text/html" } }, + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph.ts index 6980e95f11..51bb57057f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph.ts @@ -4,7 +4,7 @@ import { } from "@/app/api/__generated__/endpoints/graphs/graphs"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { parseAsInteger, parseAsString, useQueryStates } from "nuqs"; -import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs"; +import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; import { useGraphStore } from "@/app/(platform)/build/stores/graphStore"; import { useShallow } from "zustand/react/shallow"; import { useEffect, useState } from "react"; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx index d4aa26480d..62e796b748 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx @@ -47,7 +47,10 @@ export type CustomNode = XYNode; export const CustomNode: React.FC> = React.memo( ({ data, id: nodeId, selected }) => { - const { inputSchema, outputSchema } = useCustomNode({ data, nodeId }); + const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({ + data, + nodeId, + }); const isAgent = data.uiType === BlockUIType.AGENT; @@ -98,6 +101,7 @@ export const CustomNode: React.FC> = React.memo( jsonSchema={preprocessInputSchema(inputSchema)} nodeId={nodeId} uiType={data.uiType} + isMCPWithTool={isMCPWithTool} className={cn( "bg-white px-4", isWebhook && "pointer-events-none opacity-50", diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index c4659b8dcf..9a3add62b6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -20,10 +20,8 @@ type Props = { export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); - const title = - (data.metadata?.customized_name as string) || - data.hardcodedValues?.agent_name || - data.title; + + const title = (data.metadata?.customized_name as string) || data.title; const [isEditingTitle, setIsEditingTitle] = useState(false); const [editedTitle, setEditedTitle] = useState(title); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx index e58d0ab12b..050515a02f 100644 --- 
a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx @@ -3,6 +3,34 @@ import { CustomNodeData } from "./CustomNode"; import { BlockUIType } from "../../../types"; import { useMemo } from "react"; import { mergeSchemaForResolution } from "./helpers"; +/** + * Build a dynamic input schema for MCP blocks. + * + * When a tool has been selected (tool_input_schema is populated), the block + * renders the selected tool's input parameters *plus* the credentials field + * so users can select/change the OAuth credential used for execution. + * + * Static fields like server_url, selected_tool, available_tools, and + * tool_arguments are hidden because they're pre-configured from the dialog. + */ +function buildMCPInputSchema( + toolInputSchema: Record, + blockInputSchema: Record, +): Record { + // Extract the credentials field from the block's original input schema + const credentialsSchema = + blockInputSchema?.properties?.credentials ?? undefined; + + return { + type: "object", + properties: { + // Credentials field first so the dropdown appears at the top + ...(credentialsSchema ? { credentials: credentialsSchema } : {}), + ...(toolInputSchema.properties ?? {}), + }, + required: [...(toolInputSchema.required ?? [])], + }; +} export const useCustomNode = ({ data, @@ -19,10 +47,18 @@ export const useCustomNode = ({ ); const isAgent = data.uiType === BlockUIType.AGENT; + const isMCPWithTool = + data.uiType === BlockUIType.MCP_TOOL && + !!data.hardcodedValues?.tool_input_schema?.properties; const currentInputSchema = isAgent ? (data.hardcodedValues.input_schema ?? {}) - : data.inputSchema; + : isMCPWithTool + ? buildMCPInputSchema( + data.hardcodedValues.tool_input_schema, + data.inputSchema, + ) + : data.inputSchema; const currentOutputSchema = isAgent ? (data.hardcodedValues.output_schema ?? {}) : data.outputSchema; @@ -54,5 +90,6 @@ export const useCustomNode = ({ return { inputSchema, outputSchema, + isMCPWithTool, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx index d6a3fabffa..77b21dda92 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx @@ -9,39 +9,72 @@ interface FormCreatorProps { jsonSchema: RJSFSchema; nodeId: string; uiType: BlockUIType; + /** When true the block is an MCP Tool with a selected tool. */ + isMCPWithTool?: boolean; showHandles?: boolean; className?: string; } export const FormCreator: React.FC = React.memo( - ({ jsonSchema, nodeId, uiType, showHandles = true, className }) => { + ({ + jsonSchema, + nodeId, + uiType, + isMCPWithTool = false, + showHandles = true, + className, + }) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const getHardCodedValues = useNodeStore( (state) => state.getHardCodedValues, ); + const isAgent = uiType === BlockUIType.AGENT; + const handleChange = ({ formData }: any) => { if ("credentials" in formData && !formData.credentials?.id) { delete formData.credentials; } - const updatedValues = - uiType === BlockUIType.AGENT - ? 
{ - ...getHardCodedValues(nodeId), - inputs: formData, - } - : formData; + let updatedValues; + if (isAgent) { + updatedValues = { + ...getHardCodedValues(nodeId), + inputs: formData, + }; + } else if (isMCPWithTool) { + // Separate credentials from tool arguments — credentials are stored + // at the top level of hardcodedValues, not inside tool_arguments. + const { credentials, ...toolArgs } = formData; + updatedValues = { + ...getHardCodedValues(nodeId), + tool_arguments: toolArgs, + ...(credentials?.id ? { credentials } : {}), + }; + } else { + updatedValues = formData; + } updateNodeData(nodeId, { hardcodedValues: updatedValues }); }; const hardcodedValues = getHardCodedValues(nodeId); - const initialValues = - uiType === BlockUIType.AGENT - ? (hardcodedValues.inputs ?? {}) - : hardcodedValues; + + let initialValues; + if (isAgent) { + initialValues = hardcodedValues.inputs ?? {}; + } else if (isMCPWithTool) { + // Merge tool arguments with credentials for the form + initialValues = { + ...(hardcodedValues.tool_arguments ?? {}), + ...(hardcodedValues.credentials?.id + ? { credentials: hardcodedValues.credentials } + : {}), + }; + } else { + initialValues = hardcodedValues; + } return (
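// An illustrative example of how the MCP branch above splits the submitted form
// data (the example tool fields and credential values are assumptions): the
// selected credential stays at the top level of hardcodedValues while everything
// else the user typed is nested under tool_arguments.
const formData = { credentials: { id: "cred-1", provider: "mcp" }, query: "hello", limit: 5 };
const { credentials, ...toolArgs } = formData;
const updatedValues = {
  tool_arguments: toolArgs, // { query: "hello", limit: 5 }
  ...(credentials?.id ? { credentials } : {}),
};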
; + availableTools: Record; + /** Credentials meta from OAuth flow, null for public servers. */ + credentials: CredentialsMetaInput | null; +}; + +interface MCPToolDialogProps { + open: boolean; + onClose: () => void; + onConfirm: (result: MCPToolDialogResult) => void; +} + +type DialogStep = "url" | "tool"; + +export function MCPToolDialog({ + open, + onClose, + onConfirm, +}: MCPToolDialogProps) { + const allProviders = useContext(CredentialsProvidersContext); + + const [step, setStep] = useState("url"); + const [serverUrl, setServerUrl] = useState(""); + const [tools, setTools] = useState([]); + const [serverName, setServerName] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [authRequired, setAuthRequired] = useState(false); + const [oauthLoading, setOauthLoading] = useState(false); + const [showManualToken, setShowManualToken] = useState(false); + const [manualToken, setManualToken] = useState(""); + const [selectedTool, setSelectedTool] = useState( + null, + ); + const [credentials, setCredentials] = useState( + null, + ); + + const startOAuthRef = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); + + const reset = useCallback(() => { + oauthAbortRef.current?.(); + oauthAbortRef.current = null; + setStep("url"); + setServerUrl(""); + setManualToken(""); + setTools([]); + setServerName(null); + setLoading(false); + setError(null); + setAuthRequired(false); + setOauthLoading(false); + setShowManualToken(false); + setSelectedTool(null); + setCredentials(null); + }, []); + + const handleClose = useCallback(() => { + reset(); + onClose(); + }, [reset, onClose]); + + const discoverTools = useCallback(async (url: string, authToken?: string) => { + setLoading(true); + setError(null); + try { + const response = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: url, + auth_token: authToken || null, + }); + if (response.status !== 200) throw response.data; + setTools(response.data.tools); + setServerName(response.data.server_name ?? null); + setAuthRequired(false); + setShowManualToken(false); + setStep("tool"); + } catch (e: any) { + if (e?.status === 401 || e?.status === 403) { + setAuthRequired(true); + setError(null); + // Automatically start OAuth sign-in instead of requiring a second click + setLoading(false); + startOAuthRef.current = true; + return; + } else { + const message = + e?.message || e?.detail || "Failed to connect to MCP server"; + setError( + typeof message === "string" ? 
message : JSON.stringify(message), + ); + } + } finally { + setLoading(false); + } + }, []); + + const handleDiscoverTools = useCallback(() => { + if (!serverUrl.trim()) return; + discoverTools(serverUrl.trim(), manualToken.trim() || undefined); + }, [serverUrl, manualToken, discoverTools]); + + const handleOAuthSignIn = useCallback(async () => { + if (!serverUrl.trim()) return; + setError(null); + + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + setOauthLoading(true); + + try { + const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (loginResponse.status !== 200) throw loginResponse.data; + const { login_url, state_token } = loginResponse.data; + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: true, + }); + oauthAbortRef.current = cleanup.abort; + + const result = await promise; + + // Exchange code for tokens via the credentials provider (updates cache) + setLoading(true); + setOauthLoading(false); + + const mcpProvider = allProviders?.["mcp"]; + let callbackResult; + if (mcpProvider) { + callbackResult = await mcpProvider.mcpOAuthCallback( + result.code, + state_token, + ); + } else { + const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({ + code: result.code, + state_token, + }); + if (cbResponse.status !== 200) throw cbResponse.data; + callbackResult = cbResponse.data; + } + + setCredentials({ + id: callbackResult.id, + provider: callbackResult.provider, + type: callbackResult.type, + title: callbackResult.title, + }); + setAuthRequired(false); + + // Discover tools now that we're authenticated + const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (toolsResponse.status !== 200) throw toolsResponse.data; + setTools(toolsResponse.data.tools); + setServerName(toolsResponse.data.server_name ?? null); + setStep("tool"); + } catch (e: any) { + // If server doesn't support OAuth → show manual token entry + if (e?.status === 400) { + setShowManualToken(true); + setError( + "This server does not support OAuth sign-in. Please enter a token manually.", + ); + } else if (e?.message === "OAuth flow timed out") { + setError("OAuth sign-in timed out. Please try again."); + } else { + const status = e?.status; + let message: string; + if (status === 401 || status === 403) { + message = + "Authentication succeeded but the server still rejected the request. " + + "The token audience may not match. Please try again."; + } else { + message = e?.message || e?.detail || "Failed to complete sign-in"; + } + setError( + typeof message === "string" ? 
message : JSON.stringify(message), + ); + } + } finally { + setOauthLoading(false); + setLoading(false); + oauthAbortRef.current = null; + } + }, [serverUrl, allProviders]); + + // Auto-start OAuth sign-in when server returns 401/403 + useEffect(() => { + if (authRequired && startOAuthRef.current) { + startOAuthRef.current = false; + handleOAuthSignIn(); + } + }, [authRequired, handleOAuthSignIn]); + + const handleConfirm = useCallback(() => { + if (!selectedTool) return; + + const availableTools: Record = {}; + for (const t of tools) { + availableTools[t.name] = { + description: t.description, + input_schema: t.input_schema, + }; + } + + onConfirm({ + serverUrl: serverUrl.trim(), + serverName, + selectedTool: selectedTool.name, + toolInputSchema: selectedTool.input_schema, + availableTools, + credentials, + }); + reset(); + }, [ + selectedTool, + tools, + serverUrl, + serverName, + credentials, + onConfirm, + reset, + ]); + + return ( + !isOpen && handleClose()}> + + + + {step === "url" + ? "Connect to MCP Server" + : `Select a Tool${serverName ? ` — ${serverName}` : ""}`} + + + {step === "url" + ? "Enter the URL of an MCP server to discover its available tools." + : `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`} + + + + {step === "url" && ( +
+
+ + setServerUrl(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ + {/* Auth required: show manual token option */} + {authRequired && !showManualToken && ( + + )} + + {/* Manual token entry — only visible when expanded */} + {showManualToken && ( +
+ + setManualToken(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ )} + + {error &&

{error}

} +
+ )} + + {step === "tool" && ( + +
+ {tools.map((tool) => ( + setSelectedTool(tool)} + /> + ))} +
+
+ )} + + + {step === "tool" && ( + + )} + + {step === "url" && ( + + )} + {step === "tool" && ( + + )} + +
+
+ ); +} + +// --------------- Tool Card Component --------------- // + +/** Truncate a description to a reasonable length for the collapsed view. */ +function truncateDescription(text: string, maxLen = 120): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen).trimEnd() + "…"; +} + +/** Pretty-print a JSON Schema type for a parameter. */ +function schemaTypeLabel(schema: Record): string { + if (schema.type) return schema.type; + if (schema.anyOf) + return schema.anyOf.map((s: any) => s.type ?? "any").join(" | "); + if (schema.oneOf) + return schema.oneOf.map((s: any) => s.type ?? "any").join(" | "); + return "any"; +} + +function MCPToolCard({ + tool, + selected, + onSelect, +}: { + tool: MCPToolResponse; + selected: boolean; + onSelect: () => void; +}) { + const [expanded, setExpanded] = useState(false); + const schema = tool.input_schema as Record; + const properties = schema?.properties ?? {}; + const required = new Set(schema?.required ?? []); + const paramNames = Object.keys(properties); + + // Strip XML-like tags from description for cleaner display. + // Loop to handle nested tags like ipt> (CodeQL fix). + let cleanDescription = tool.description ?? ""; + let prev = ""; + while (prev !== cleanDescription) { + prev = cleanDescription; + cleanDescription = cleanDescription.replace(/<[^>]*>/g, ""); + } + cleanDescription = cleanDescription.trim(); + + return ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx index 10f4fc8a44..07c6795808 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/__legacy__/ui/button"; import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { beautifyString, cn } from "@/lib/utils"; -import React, { ButtonHTMLAttributes } from "react"; +import React, { ButtonHTMLAttributes, useCallback, useState } from "react"; import { highlightText } from "./helpers"; import { PlusIcon } from "@phosphor-icons/react"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; @@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore"; import { blockDragPreviewStyle } from "./style"; import { useReactFlow } from "@xyflow/react"; import { useNodeStore } from "../../../stores/nodeStore"; +import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api"; +import { + MCPToolDialog, + type MCPToolDialogResult, +} from "@/app/(platform)/build/components/MCPToolDialog"; + interface Props extends ButtonHTMLAttributes { title?: string; description?: string; @@ -33,22 +39,86 @@ export const Block: BlockComponent = ({ ); const { setViewport } = useReactFlow(); const { addBlock } = useNodeStore(); + const [mcpDialogOpen, setMcpDialogOpen] = useState(false); + + const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL; + + const addBlockAndCenter = useCallback( + (block: BlockInfo, hardcodedValues?: Record) => { + const customNode = addBlock(block, hardcodedValues); + setTimeout(() => { + setViewport( + { + x: -customNode.position.x * 0.8 + window.innerWidth / 2, + y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, + zoom: 0.8, + }, + { duration: 500 }, + ); + }, 50); + return 
customNode; + }, + [addBlock, setViewport], + ); + + const updateNodeData = useNodeStore((state) => state.updateNodeData); + + const handleMCPToolConfirm = useCallback( + (result: MCPToolDialogResult) => { + // Derive a display label: prefer server name, fall back to URL hostname. + let serverLabel = result.serverName; + if (!serverLabel) { + try { + serverLabel = new URL(result.serverUrl).hostname; + } catch { + serverLabel = "MCP"; + } + } + + const customNode = addBlockAndCenter(blockData, { + server_url: result.serverUrl, + server_name: serverLabel, + selected_tool: result.selectedTool, + tool_input_schema: result.toolInputSchema, + available_tools: result.availableTools, + credentials: result.credentials ?? undefined, + }); + if (customNode) { + const title = result.selectedTool + ? `${serverLabel}: ${beautifyString(result.selectedTool)}` + : undefined; + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + credentials_optional: true, + ...(title && { customized_name: title }), + }, + }); + } + setMcpDialogOpen(false); + }, + [addBlockAndCenter, blockData, updateNodeData], + ); const handleClick = () => { - const customNode = addBlock(blockData); - setTimeout(() => { - setViewport( - { - x: -customNode.position.x * 0.8 + window.innerWidth / 2, - y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, - zoom: 0.8, + if (isMCPBlock) { + setMcpDialogOpen(true); + return; + } + const customNode = addBlockAndCenter(blockData); + // Set customized_name for agent blocks so the agent's name persists + if (customNode && blockData.id === SpecialBlockID.AGENT) { + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + customized_name: blockData.name, }, - { duration: 500 }, - ); - }, 50); + }); + } }; const handleDragStart = (e: React.DragEvent) => { + if (isMCPBlock) return; e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData)); @@ -71,46 +141,56 @@ export const Block: BlockComponent = ({ : undefined; return ( -
- +
+ {title && ( + + {highlightText(beautifyString(title), highlightedText)} + + )} + {description && ( + + {highlightText(description, highlightedText)} + + )} +
+
+ +
+ + {isMCPBlock && ( + setMcpDialogOpen(false)} + onConfirm={handleMCPToolConfirm} + /> + )} + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx index cb06a79683..f7d59a5693 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx @@ -1,6 +1,6 @@ import { useCallback } from "react"; -import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view"; +import { AgentRunDraftView } from "@/app/(platform)/build/components/legacy-builder/agent-run-draft-view"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import type { CredentialsMetaInput, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/SaveControl.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/SaveControl.tsx index dcaa0f6264..3ee5217354 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/SaveControl.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/SaveControl.tsx @@ -18,7 +18,7 @@ import { import { useToast } from "@/components/molecules/Toast/use-toast"; import { useQueryClient } from "@tanstack/react-query"; import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store"; -import { CronExpressionDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog"; +import { CronExpressionDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog"; import { humanizeCronExpression } from "@/lib/cron-expression-utils"; import { CalendarClockIcon } from "lucide-react"; diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/agent-run-draft-view.tsx similarity index 99% rename from autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx rename to autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/agent-run-draft-view.tsx index b0c3a6ff7b..372d479299 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/agent-run-draft-view.tsx @@ -20,7 +20,7 @@ import { import { useBackendAPI } from "@/lib/autogpt-server-api/context"; import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs"; -import { ScheduleTaskDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog"; +import { ScheduleTaskDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog"; import ActionButtonGroup from "@/components/__legacy__/action-button-group"; import type { ButtonAction } from "@/components/__legacy__/types"; import { @@ -53,7 +53,10 @@ import { ClockIcon, CopyIcon, InfoIcon } from "@phosphor-icons/react"; import { CalendarClockIcon, 
Trash2Icon } from "lucide-react"; import { analytics } from "@/services/analytics"; -import { AgentStatus, AgentStatusChip } from "./agent-status-chip"; +import { + AgentStatus, + AgentStatusChip, +} from "@/app/(platform)/build/components/legacy-builder/agent-status-chip"; export function AgentRunDraftView({ graph, diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-status-chip.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/agent-status-chip.tsx similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-status-chip.tsx rename to autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/agent-status-chip.tsx diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts index 2fde427330..0f5021351d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts @@ -9,4 +9,5 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index 71ade81a9f..c118057963 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -15,11 +15,16 @@ import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai"; import { useEffect, useRef, useState } from "react"; import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; import { RunBlockTool } from "../../tools/RunBlock/RunBlock"; import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs"; +import { GenericTool } from "../../tools/GenericTool/GenericTool"; import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput"; // --------------------------------------------------------------------------- @@ -254,7 +259,31 @@ export const ChatMessagesContainer = ({ part={part as ToolUIPart} /> ); + case "tool-search_feature_requests": + return ( + + ); + case "tool-create_feature_request": + return ( + + ); default: + // Render a generic tool indicator for SDK built-in + // tools (Read, Glob, Grep, etc.) 
or any unrecognized tool + if (part.type.startsWith("tool-")) { + return ( + + ); + } return null; } })} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx index 6030665f1c..8a35f939ca 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx @@ -14,6 +14,10 @@ import { Text } from "@/components/atoms/Text/Text"; import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider"; import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../tools/RunAgent/RunAgent"; @@ -45,6 +49,8 @@ const SECTIONS = [ "Tool: Create Agent", "Tool: Edit Agent", "Tool: View Agent Output", + "Tool: Search Feature Requests", + "Tool: Create Feature Request", "Full Conversation Example", ] as const; @@ -1421,6 +1427,235 @@ export default function StyleguidePage() { + {/* ============================================================= */} + {/* SEARCH FEATURE REQUESTS */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ + {/* ============================================================= */} + {/* CREATE FEATURE REQUEST */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ {/* ============================================================= */} {/* FULL CONVERSATION EXAMPLE */} {/* ============================================================= */} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx new file mode 100644 index 0000000000..fcd4624b6a --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx @@ -0,0 +1,227 @@ +"use client"; + +import type { ToolUIPart } from "ai"; +import { useMemo } from "react"; + +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardHeader, + ContentCardTitle, + ContentGrid, + ContentMessage, + ContentSuggestionsList, +} from "../../components/ToolAccordion/AccordionContent"; +import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { + AccordionIcon, + getAccordionTitle, + getAnimationText, + getFeatureRequestOutput, + isCreatedOutput, + isErrorOutput, + isNoResultsOutput, + isSearchResultsOutput, + ToolIcon, + type FeatureRequestToolType, +} from "./helpers"; + +export interface FeatureRequestToolPart { + type: FeatureRequestToolType; + toolCallId: string; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +interface Props { + part: FeatureRequestToolPart; +} + +function truncate(text: string, maxChars: number): string { + const trimmed = text.trim(); + if (trimmed.length <= maxChars) return trimmed; + return `${trimmed.slice(0, maxChars).trimEnd()}…`; +} + +export function SearchFeatureRequestsTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const searchOutput = + isOutputAvailable && output && isSearchResultsOutput(output) + ? output + : null; + const noResultsOutput = + isOutputAvailable && output && isNoResultsOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && + ((!!searchOutput && searchOutput.count > 0) || + !!noResultsOutput || + !!errorOutput); + + const accordionDescription = + hasExpandableContent && searchOutput + ? `Found ${searchOutput.count} result${searchOutput.count === 1 ? "" : "s"} for "${searchOutput.query}"` + : hasExpandableContent && (noResultsOutput || errorOutput) + ? ((noResultsOutput ?? errorOutput)?.message ?? null) + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {searchOutput && ( + + {searchOutput.results.map((r) => ( + + + {r.title} + + {r.description && ( + + {truncate(r.description, 200)} + + )} + + ))} + + )} + + {noResultsOutput && ( +
+ {noResultsOutput.message} + {noResultsOutput.suggestions && + noResultsOutput.suggestions.length > 0 && ( + + )} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+  );
+}
+
+export function CreateFeatureRequestTool({ part }: Props) {
+  const output = getFeatureRequestOutput(part);
+  const text = getAnimationText(part);
+  const isStreaming =
+    part.state === "input-streaming" || part.state === "input-available";
+  const isError =
+    part.state === "output-error" || (!!output && isErrorOutput(output));
+
+  const normalized = useMemo(() => {
+    if (!output) return null;
+    return { title: getAccordionTitle(part.type, output) };
+  }, [output, part.type]);
+
+  const isOutputAvailable = part.state === "output-available" && !!output;
+
+  const createdOutput =
+    isOutputAvailable && output && isCreatedOutput(output) ? output : null;
+  const errorOutput =
+    isOutputAvailable && output && isErrorOutput(output) ? output : null;
+
+  const hasExpandableContent =
+    isOutputAvailable && (!!createdOutput || !!errorOutput);
+
+  const accordionDescription =
+    hasExpandableContent && createdOutput
+      ? createdOutput.issue_title
+      : hasExpandableContent && errorOutput
+        ? errorOutput.message
+        : null;
+
+  return (
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {createdOutput && ( + + + {createdOutput.issue_title} + +
+ + {createdOutput.is_new_issue ? "New" : "Existing"} + +
+ {createdOutput.message} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+  );
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx
new file mode 100644
index 0000000000..75133905b1
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx
@@ -0,0 +1,271 @@
+import {
+  CheckCircleIcon,
+  LightbulbIcon,
+  MagnifyingGlassIcon,
+  PlusCircleIcon,
+} from "@phosphor-icons/react";
+import type { ToolUIPart } from "ai";
+
+/* ------------------------------------------------------------------ */
+/* Types (local until API client is regenerated)                      */
+/* ------------------------------------------------------------------ */
+
+interface FeatureRequestInfo {
+  id: string;
+  identifier: string;
+  title: string;
+  description?: string | null;
+}
+
+export interface FeatureRequestSearchResponse {
+  type: "feature_request_search";
+  message: string;
+  results: FeatureRequestInfo[];
+  count: number;
+  query: string;
+}
+
+export interface FeatureRequestCreatedResponse {
+  type: "feature_request_created";
+  message: string;
+  issue_id: string;
+  issue_identifier: string;
+  issue_title: string;
+  issue_url: string;
+  is_new_issue: boolean;
+  customer_name: string;
+}
+
+interface NoResultsResponse {
+  type: "no_results";
+  message: string;
+  suggestions?: string[];
+}
+
+interface ErrorResponse {
+  type: "error";
+  message: string;
+  error?: string;
+}
+
+export type FeatureRequestOutput =
+  | FeatureRequestSearchResponse
+  | FeatureRequestCreatedResponse
+  | NoResultsResponse
+  | ErrorResponse;
+
+export type FeatureRequestToolType =
+  | "tool-search_feature_requests"
+  | "tool-create_feature_request"
+  | string;
+
+/* ------------------------------------------------------------------ */
+/* Output parsing                                                      */
+/* ------------------------------------------------------------------ */
+
+function parseOutput(output: unknown): FeatureRequestOutput | null {
+  if (!output) return null;
+  if (typeof output === "string") {
+    const trimmed = output.trim();
+    if (!trimmed) return null;
+    try {
+      return parseOutput(JSON.parse(trimmed) as unknown);
+    } catch {
+      return null;
+    }
+  }
+  if (typeof output === "object") {
+    const type = (output as { type?: unknown }).type;
+    if (
+      type === "feature_request_search" ||
+      type === "feature_request_created" ||
+      type === "no_results" ||
+      type === "error"
+    ) {
+      return output as FeatureRequestOutput;
+    }
+    // Fallback structural checks
+    if ("results" in output && "query" in output)
+      return output as FeatureRequestSearchResponse;
+    if ("issue_identifier" in output)
+      return output as FeatureRequestCreatedResponse;
+    if ("suggestions" in output && !("error" in output))
+      return output as NoResultsResponse;
+    if ("error" in output || "details" in output)
+      return output as ErrorResponse;
+  }
+  return null;
+}
+
+export function getFeatureRequestOutput(
+  part: unknown,
+): FeatureRequestOutput | null {
+  if (!part || typeof part !== "object") return null;
+  return parseOutput((part as { output?: unknown }).output);
+}
+
+/* ------------------------------------------------------------------ */
+/* Type guards                                                         */
+/* ------------------------------------------------------------------ */
+
+export function isSearchResultsOutput(
+  output: FeatureRequestOutput,
+): output is FeatureRequestSearchResponse {
+  return (
+    output.type === "feature_request_search" ||
+    ("results" in output && "query" in output)
+  );
+}
+
+export function isCreatedOutput(
+  output: FeatureRequestOutput,
+): output is FeatureRequestCreatedResponse {
+  return (
+    output.type === "feature_request_created" || "issue_identifier" in output
+  );
+}
+
+export function isNoResultsOutput(
+  output: FeatureRequestOutput,
+): output is NoResultsResponse {
+  return (
+    output.type === "no_results" ||
+    ("suggestions" in output && !("error" in output))
+  );
+}
+
+export function isErrorOutput(
+  output: FeatureRequestOutput,
+): output is ErrorResponse {
+  return output.type === "error" || "error" in output;
+}
+
+/* ------------------------------------------------------------------ */
+/* Accordion metadata                                                  */
+/* ------------------------------------------------------------------ */
+
+export function getAccordionTitle(
+  toolType: FeatureRequestToolType,
+  output: FeatureRequestOutput,
+): string {
+  if (toolType === "tool-search_feature_requests") {
+    if (isSearchResultsOutput(output)) return "Feature requests";
+    if (isNoResultsOutput(output)) return "No feature requests found";
+    return "Feature request search error";
+  }
+  if (isCreatedOutput(output)) {
+    return output.is_new_issue
+      ? "Feature request created"
+      : "Added to feature request";
+  }
+  if (isErrorOutput(output)) return "Feature request error";
+  return "Feature request";
+}
+
+/* ------------------------------------------------------------------ */
+/* Animation text                                                      */
+/* ------------------------------------------------------------------ */
+
+interface AnimationPart {
+  type: FeatureRequestToolType;
+  state: ToolUIPart["state"];
+  input?: unknown;
+  output?: unknown;
+}
+
+export function getAnimationText(part: AnimationPart): string {
+  if (part.type === "tool-search_feature_requests") {
+    const query = (part.input as { query?: string } | undefined)?.query?.trim();
+    const queryText = query ? ` for "${query}"` : "";
+
+    switch (part.state) {
+      case "input-streaming":
+      case "input-available":
+        return `Searching feature requests${queryText}`;
+      case "output-available": {
+        const output = parseOutput(part.output);
+        if (!output) return `Searching feature requests${queryText}`;
+        if (isSearchResultsOutput(output)) {
+          return `Found ${output.count} feature request${output.count === 1 ? "" : "s"}${queryText}`;
+        }
+        if (isNoResultsOutput(output))
+          return `No feature requests found${queryText}`;
+        return `Error searching feature requests${queryText}`;
+      }
+      case "output-error":
+        return `Error searching feature requests${queryText}`;
+      default:
+        return "Searching feature requests";
+    }
+  }
+
+  // create_feature_request
+  const title = (part.input as { title?: string } | undefined)?.title?.trim();
+  const titleText = title ? ` "${title}"` : "";
+
+  switch (part.state) {
+    case "input-streaming":
+    case "input-available":
+      return `Creating feature request${titleText}`;
+    case "output-available": {
+      const output = parseOutput(part.output);
+      if (!output) return `Creating feature request${titleText}`;
+      if (isCreatedOutput(output)) {
+        return output.is_new_issue
+          ? "Feature request created"
+          : "Added to existing feature request";
+      }
+      if (isErrorOutput(output)) return "Error creating feature request";
+      return `Created feature request${titleText}`;
+    }
+    case "output-error":
+      return "Error creating feature request";
+    default:
+      return "Creating feature request";
+  }
+}
+
+/* ------------------------------------------------------------------ */
+/* Icons                                                               */
+/* ------------------------------------------------------------------ */
+
+export function ToolIcon({
+  toolType,
+  isStreaming,
+  isError,
+}: {
+  toolType: FeatureRequestToolType;
+  isStreaming?: boolean;
+  isError?: boolean;
+}) {
+  const IconComponent =
+    toolType === "tool-create_feature_request"
+      ? PlusCircleIcon
+      : MagnifyingGlassIcon;
+
+  return (
+
+  );
+}
+
+export function AccordionIcon({
+  toolType,
+}: {
+  toolType: FeatureRequestToolType;
+}) {
+  const IconComponent =
+    toolType === "tool-create_feature_request"
+      ? CheckCircleIcon
+      : LightbulbIcon;
+  return ;
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx
new file mode 100644
index 0000000000..677f1d01d1
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx
@@ -0,0 +1,63 @@
+"use client";
+
+import { ToolUIPart } from "ai";
+import { GearIcon } from "@phosphor-icons/react";
+import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
+
+interface Props {
+  part: ToolUIPart;
+}
+
+function extractToolName(part: ToolUIPart): string {
+  // ToolUIPart.type is "tool-{name}", extract the name portion.
+  return part.type.replace(/^tool-/, "");
+}
+
+function formatToolName(name: string): string {
+  // "search_docs" → "Search docs", "Read" → "Read"
+  return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
+}
+
+function getAnimationText(part: ToolUIPart): string {
+  const label = formatToolName(extractToolName(part));
+
+  switch (part.state) {
+    case "input-streaming":
+    case "input-available":
+      return `Running ${label}…`;
+    case "output-available":
+      return `${label} completed`;
+    case "output-error":
+      return `${label} failed`;
+    default:
+      return `Running ${label}…`;
+  }
+}
+
+export function GenericTool({ part }: Props) {
+  const isStreaming =
+    part.state === "input-streaming" || part.state === "input-available";
+  const isError = part.state === "output-error";
+
+  return (
+
+ + +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/OldAgentLibraryView.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/OldAgentLibraryView.tsx deleted file mode 100644 index 54cc07878d..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/OldAgentLibraryView.tsx +++ /dev/null @@ -1,631 +0,0 @@ -"use client"; -import { useParams, useRouter } from "next/navigation"; -import { useQueryState } from "nuqs"; -import React, { - useCallback, - useEffect, - useMemo, - useRef, - useState, -} from "react"; - -import { - Graph, - GraphExecution, - GraphExecutionID, - GraphExecutionMeta, - GraphID, - LibraryAgent, - LibraryAgentID, - LibraryAgentPreset, - LibraryAgentPresetID, - Schedule, - ScheduleID, -} from "@/lib/autogpt-server-api"; -import { useBackendAPI } from "@/lib/autogpt-server-api/context"; -import { exportAsJSONFile } from "@/lib/utils"; - -import DeleteConfirmDialog from "@/components/__legacy__/delete-confirm-dialog"; -import type { ButtonAction } from "@/components/__legacy__/types"; -import { Button } from "@/components/__legacy__/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, -} from "@/components/__legacy__/ui/dialog"; -import LoadingBox, { LoadingSpinner } from "@/components/__legacy__/ui/loading"; -import { - useToast, - useToastOnFail, -} from "@/components/molecules/Toast/use-toast"; -import { AgentRunDetailsView } from "./components/agent-run-details-view"; -import { AgentRunDraftView } from "./components/agent-run-draft-view"; -import { CreatePresetDialog } from "./components/create-preset-dialog"; -import { useAgentRunsInfinite } from "./use-agent-runs"; -import { AgentRunsSelectorList } from "./components/agent-runs-selector-list"; -import { AgentScheduleDetailsView } from "./components/agent-schedule-details-view"; - -export function OldAgentLibraryView() { - const { id: agentID }: { id: LibraryAgentID } = useParams(); - const [executionId, setExecutionId] = useQueryState("executionId"); - const toastOnFail = useToastOnFail(); - const { toast } = useToast(); - const router = useRouter(); - const api = useBackendAPI(); - - // ============================ STATE ============================= - - const [graph, setGraph] = useState(null); // Graph version corresponding to LibraryAgent - const [agent, setAgent] = useState(null); - const agentRunsQuery = useAgentRunsInfinite(graph?.id); // only runs once graph.id is known - const agentRuns = agentRunsQuery.agentRuns; - const [agentPresets, setAgentPresets] = useState([]); - const [schedules, setSchedules] = useState([]); - const [selectedView, selectView] = useState< - | { type: "run"; id?: GraphExecutionID } - | { type: "preset"; id: LibraryAgentPresetID } - | { type: "schedule"; id: ScheduleID } - >({ type: "run" }); - const [selectedRun, setSelectedRun] = useState< - GraphExecution | GraphExecutionMeta | null - >(null); - const selectedSchedule = - selectedView.type == "schedule" - ? 
schedules.find((s) => s.id == selectedView.id) - : null; - const [isFirstLoad, setIsFirstLoad] = useState(true); - const [agentDeleteDialogOpen, setAgentDeleteDialogOpen] = - useState(false); - const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] = - useState(null); - const [confirmingDeleteAgentPreset, setConfirmingDeleteAgentPreset] = - useState(null); - const [copyAgentDialogOpen, setCopyAgentDialogOpen] = useState(false); - const [creatingPresetFromExecutionID, setCreatingPresetFromExecutionID] = - useState(null); - - // Set page title with agent name - useEffect(() => { - if (agent) { - document.title = `${agent.name} - Library - AutoGPT Platform`; - } - }, [agent]); - - const openRunDraftView = useCallback(() => { - selectView({ type: "run" }); - }, []); - - const selectRun = useCallback((id: GraphExecutionID) => { - selectView({ type: "run", id }); - }, []); - - const selectPreset = useCallback((id: LibraryAgentPresetID) => { - selectView({ type: "preset", id }); - }, []); - - const selectSchedule = useCallback((id: ScheduleID) => { - selectView({ type: "schedule", id }); - }, []); - - const graphVersions = useRef>({}); - const loadingGraphVersions = useRef>>({}); - const getGraphVersion = useCallback( - async (graphID: GraphID, version: number) => { - if (version in graphVersions.current) - return graphVersions.current[version]; - if (version in loadingGraphVersions.current) - return loadingGraphVersions.current[version]; - - const pendingGraph = api.getGraph(graphID, version).then((graph) => { - graphVersions.current[version] = graph; - return graph; - }); - // Cache promise as well to avoid duplicate requests - loadingGraphVersions.current[version] = pendingGraph; - return pendingGraph; - }, - [api, graphVersions, loadingGraphVersions], - ); - - const lastRefresh = useRef(0); - const refreshPageData = useCallback(() => { - if (Date.now() - lastRefresh.current < 2e3) return; // 2 second debounce - lastRefresh.current = Date.now(); - - api.getLibraryAgent(agentID).then((agent) => { - setAgent(agent); - - getGraphVersion(agent.graph_id, agent.graph_version).then( - (_graph) => - (graph && graph.version == _graph.version) || setGraph(_graph), - ); - Promise.all([ - agentRunsQuery.refetchRuns(), - api.listLibraryAgentPresets({ - graph_id: agent.graph_id, - page_size: 100, - }), - ]).then(([runsQueryResult, presets]) => { - setAgentPresets(presets.presets); - - const newestAgentRunsResponse = runsQueryResult.data?.pages[0]; - if (!newestAgentRunsResponse || newestAgentRunsResponse.status != 200) - return; - const newestAgentRuns = newestAgentRunsResponse.data.executions; - // Preload the corresponding graph versions for the latest 10 runs - new Set( - newestAgentRuns.slice(0, 10).map((run) => run.graph_version), - ).forEach((version) => getGraphVersion(agent.graph_id, version)); - }); - }); - }, [api, agentID, getGraphVersion, graph]); - - // On first load: select the latest run - useEffect(() => { - // Only for first load or first execution - if (selectedView.id || !isFirstLoad) return; - if (agentRuns.length == 0 && agentPresets.length == 0) return; - - setIsFirstLoad(false); - if (agentRuns.length > 0) { - // select latest run - const latestRun = agentRuns.reduce((latest, current) => { - if (!latest.started_at && !current.started_at) return latest; - if (!latest.started_at) return current; - if (!current.started_at) return latest; - return latest.started_at > current.started_at ? 
latest : current; - }, agentRuns[0]); - selectRun(latestRun.id as GraphExecutionID); - } else { - // select top preset - const latestPreset = agentPresets.toSorted( - (a, b) => b.updated_at.getTime() - a.updated_at.getTime(), - )[0]; - selectPreset(latestPreset.id); - } - }, [ - isFirstLoad, - selectedView.id, - agentRuns, - agentPresets, - selectRun, - selectPreset, - ]); - - useEffect(() => { - if (executionId) { - selectRun(executionId as GraphExecutionID); - setExecutionId(null); - } - }, [executionId, selectRun, setExecutionId]); - - // Initial load - useEffect(() => { - refreshPageData(); - - // Show a toast when the WebSocket connection disconnects - let connectionToast: ReturnType | null = null; - const cancelDisconnectHandler = api.onWebSocketDisconnect(() => { - connectionToast ??= toast({ - title: "Connection to server was lost", - variant: "destructive", - description: ( -
- Trying to reconnect... - -
- ), - duration: Infinity, - dismissable: true, - }); - }); - const cancelConnectHandler = api.onWebSocketConnect(() => { - if (connectionToast) - connectionToast.update({ - id: connectionToast.id, - title: "✅ Connection re-established", - variant: "default", - description: ( -
- Refreshing data... - -
- ), - duration: 2000, - dismissable: true, - }); - connectionToast = null; - }); - return () => { - cancelDisconnectHandler(); - cancelConnectHandler(); - }; - }, []); - - // Subscribe to WebSocket updates for agent runs - useEffect(() => { - if (!agent?.graph_id) return; - - return api.onWebSocketConnect(() => { - refreshPageData(); // Sync up on (re)connect - - // Subscribe to all executions for this agent - api.subscribeToGraphExecutions(agent.graph_id); - }); - }, [api, agent?.graph_id, refreshPageData]); - - // Handle execution updates - useEffect(() => { - const detachExecUpdateHandler = api.onWebSocketMessage( - "graph_execution_event", - (data) => { - if (data.graph_id != agent?.graph_id) return; - - agentRunsQuery.upsertAgentRun(data); - if (data.id === selectedView.id) { - // Update currently viewed run - setSelectedRun(data); - } - }, - ); - - return () => { - detachExecUpdateHandler(); - }; - }, [api, agent?.graph_id, selectedView.id]); - - // Pre-load selectedRun based on selectedView - useEffect(() => { - if (selectedView.type != "run" || !selectedView.id) return; - - const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id); - if (selectedView.id !== selectedRun?.id) { - // Pull partial data from "cache" while waiting for the rest to load - setSelectedRun((newSelectedRun as GraphExecutionMeta) ?? null); - } - }, [api, selectedView, agentRuns, selectedRun?.id]); - - // Load selectedRun based on selectedView; refresh on agent refresh - useEffect(() => { - if (selectedView.type != "run" || !selectedView.id || !agent) return; - - api - .getGraphExecutionInfo(agent.graph_id, selectedView.id) - .then(async (run) => { - // Ensure corresponding graph version is available before rendering I/O - await getGraphVersion(run.graph_id, run.graph_version); - setSelectedRun(run); - }); - }, [api, selectedView, agent, getGraphVersion]); - - const fetchSchedules = useCallback(async () => { - if (!agent) return; - - setSchedules(await api.listGraphExecutionSchedules(agent.graph_id)); - }, [api, agent?.graph_id]); - - useEffect(() => { - fetchSchedules(); - }, [fetchSchedules]); - - // =========================== ACTIONS ============================ - - const deleteRun = useCallback( - async (run: GraphExecutionMeta) => { - if (run.status == "RUNNING" || run.status == "QUEUED") { - await api.stopGraphExecution(run.graph_id, run.id); - } - await api.deleteGraphExecution(run.id); - - setConfirmingDeleteAgentRun(null); - if (selectedView.type == "run" && selectedView.id == run.id) { - openRunDraftView(); - } - agentRunsQuery.removeAgentRun(run.id); - }, - [api, selectedView, openRunDraftView], - ); - - const deletePreset = useCallback( - async (presetID: LibraryAgentPresetID) => { - await api.deleteLibraryAgentPreset(presetID); - - setConfirmingDeleteAgentPreset(null); - if (selectedView.type == "preset" && selectedView.id == presetID) { - openRunDraftView(); - } - setAgentPresets((presets) => presets.filter((p) => p.id !== presetID)); - }, - [api, selectedView, openRunDraftView], - ); - - const deleteSchedule = useCallback( - async (scheduleID: ScheduleID) => { - const removedSchedule = - await api.deleteGraphExecutionSchedule(scheduleID); - - setSchedules((schedules) => { - const newSchedules = schedules.filter( - (s) => s.id !== removedSchedule.id, - ); - if ( - selectedView.type == "schedule" && - selectedView.id == removedSchedule.id - ) { - if (newSchedules.length > 0) { - // Select next schedule if available - selectSchedule(newSchedules[0].id); - } else { - // Reset to 
draft view if current schedule was deleted - openRunDraftView(); - } - } - return newSchedules; - }); - openRunDraftView(); - }, - [schedules, api], - ); - - const handleCreatePresetFromRun = useCallback( - async (name: string, description: string) => { - if (!creatingPresetFromExecutionID) return; - - await api - .createLibraryAgentPreset({ - name, - description, - graph_execution_id: creatingPresetFromExecutionID, - }) - .then((preset) => { - setAgentPresets((prev) => [...prev, preset]); - selectPreset(preset.id); - setCreatingPresetFromExecutionID(null); - }) - .catch(toastOnFail("create a preset")); - }, - [api, creatingPresetFromExecutionID, selectPreset, toast], - ); - - const downloadGraph = useCallback( - async () => - agent && - // Export sanitized graph from backend - api - .getGraph(agent.graph_id, agent.graph_version, true) - .then((graph) => - exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`), - ), - [api, agent], - ); - - const copyAgent = useCallback(async () => { - setCopyAgentDialogOpen(false); - api - .forkLibraryAgent(agentID) - .then((newAgent) => { - router.push(`/library/agents/${newAgent.id}`); - }) - .catch((error) => { - console.error("Error copying agent:", error); - toast({ - title: "Error copying agent", - description: `An error occurred while copying the agent: ${error.message}`, - variant: "destructive", - }); - }); - }, [agentID, api, router, toast]); - - const agentActions: ButtonAction[] = useMemo( - () => [ - { - label: "Customize agent", - href: `/build?flowID=${agent?.graph_id}&flowVersion=${agent?.graph_version}`, - disabled: !agent?.can_access_graph, - }, - { label: "Export agent to file", callback: downloadGraph }, - ...(!agent?.can_access_graph - ? [ - { - label: "Edit a copy", - callback: () => setCopyAgentDialogOpen(true), - }, - ] - : []), - { - label: "Delete agent", - callback: () => setAgentDeleteDialogOpen(true), - }, - ], - [agent, downloadGraph], - ); - - const runGraph = - graphVersions.current[selectedRun?.graph_version ?? 0] ?? graph; - - const onCreateSchedule = useCallback( - (schedule: Schedule) => { - setSchedules((prev) => [...prev, schedule]); - selectSchedule(schedule.id); - }, - [selectView], - ); - - const onCreatePreset = useCallback( - (preset: LibraryAgentPreset) => { - setAgentPresets((prev) => [...prev, preset]); - selectPreset(preset.id); - }, - [selectPreset], - ); - - const onUpdatePreset = useCallback( - (updated: LibraryAgentPreset) => { - setAgentPresets((prev) => - prev.map((p) => (p.id === updated.id ? updated : p)), - ); - selectPreset(updated.id); - }, - [selectPreset], - ); - - if (!agent || !graph) { - return ; - } - - return ( -
- {/* Sidebar w/ list of runs */} - {/* TODO: render this below header in sm and md layouts */} - - -
- {/* Header */} -
-

- { - agent.name /* TODO: use dynamic/custom run title - https://github.com/Significant-Gravitas/AutoGPT/issues/9184 */ - } -

-
- - {/* Run / Schedule views */} - {(selectedView.type == "run" && selectedView.id ? ( - selectedRun && runGraph ? ( - setConfirmingDeleteAgentRun(selectedRun)} - doCreatePresetFromRun={() => - setCreatingPresetFromExecutionID(selectedRun.id) - } - /> - ) : null - ) : selectedView.type == "run" ? ( - /* Draft new runs / Create new presets */ - - ) : selectedView.type == "preset" ? ( - /* Edit & update presets */ - preset.id == selectedView.id)! - } - onRun={selectRun} - recommendedScheduleCron={agent?.recommended_schedule_cron || null} - onCreateSchedule={onCreateSchedule} - onUpdatePreset={onUpdatePreset} - doDeletePreset={setConfirmingDeleteAgentPreset} - agentActions={agentActions} - /> - ) : selectedView.type == "schedule" ? ( - selectedSchedule && - graph && ( - - ) - ) : null) || } - - - agent && - api.deleteLibraryAgent(agent.id).then(() => router.push("/library")) - } - /> - - !open && setConfirmingDeleteAgentRun(null)} - onDoDelete={() => - confirmingDeleteAgentRun && deleteRun(confirmingDeleteAgentRun) - } - /> - !open && setConfirmingDeleteAgentPreset(null)} - onDoDelete={() => - confirmingDeleteAgentPreset && - deletePreset(confirmingDeleteAgentPreset) - } - /> - {/* Copy agent confirmation dialog */} - - - - You're making an editable copy - - The original Marketplace agent stays the same and cannot be - edited. We'll save a new version of this agent to your - Library. From there, you can customize it however you'd - like by clicking "Customize agent" — this will open - the builder where you can see and modify the inner workings. - - - - - - - - - setCreatingPresetFromExecutionID(null)} - onConfirm={handleCreatePresetFromRun} - /> -
-
- ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-details-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-details-view.tsx deleted file mode 100644 index eb5224c958..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-details-view.tsx +++ /dev/null @@ -1,445 +0,0 @@ -"use client"; -import { format, formatDistanceToNow, formatDistanceStrict } from "date-fns"; -import React, { useCallback, useMemo, useEffect } from "react"; - -import { - Graph, - GraphExecution, - GraphExecutionID, - GraphExecutionMeta, - LibraryAgent, -} from "@/lib/autogpt-server-api"; -import { useBackendAPI } from "@/lib/autogpt-server-api/context"; - -import ActionButtonGroup from "@/components/__legacy__/action-button-group"; -import type { ButtonAction } from "@/components/__legacy__/types"; -import { - Card, - CardContent, - CardHeader, - CardTitle, -} from "@/components/__legacy__/ui/card"; -import { - IconRefresh, - IconSquare, - IconCircleAlert, -} from "@/components/__legacy__/ui/icons"; -import { Input } from "@/components/__legacy__/ui/input"; -import LoadingBox from "@/components/__legacy__/ui/loading"; -import { - Tooltip, - TooltipContent, - TooltipProvider, - TooltipTrigger, -} from "@/components/atoms/Tooltip/BaseTooltip"; -import { useToastOnFail } from "@/components/molecules/Toast/use-toast"; - -import { AgentRunStatus, agentRunStatusMap } from "./agent-run-status-chip"; -import useCredits from "@/hooks/useCredits"; -import { AgentRunOutputView } from "./agent-run-output-view"; -import { analytics } from "@/services/analytics"; -import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList"; -import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews"; - -export function AgentRunDetailsView({ - agent, - graph, - run, - agentActions, - onRun, - doDeleteRun, - doCreatePresetFromRun, -}: { - agent: LibraryAgent; - graph: Graph; - run: GraphExecution | GraphExecutionMeta; - agentActions: ButtonAction[]; - onRun: (runID: GraphExecutionID) => void; - doDeleteRun: () => void; - doCreatePresetFromRun: () => void; -}): React.ReactNode { - const api = useBackendAPI(); - const { formatCredits } = useCredits(); - - const runStatus: AgentRunStatus = useMemo( - () => agentRunStatusMap[run.status], - [run], - ); - - const { - pendingReviews, - isLoading: reviewsLoading, - refetch: refetchReviews, - } = usePendingReviewsForExecution(run.id); - - const toastOnFail = useToastOnFail(); - - // Refetch pending reviews when execution status changes to REVIEW - useEffect(() => { - if (runStatus === "review" && run.id) { - refetchReviews(); - } - }, [runStatus, run.id, refetchReviews]); - - const infoStats: { label: string; value: React.ReactNode }[] = useMemo(() => { - if (!run) return []; - return [ - { - label: "Status", - value: runStatus.charAt(0).toUpperCase() + runStatus.slice(1), - }, - { - label: "Started", - value: run.started_at - ? `${formatDistanceToNow(run.started_at, { addSuffix: true })}, ${format(run.started_at, "HH:mm")}` - : "—", - }, - ...(run.stats - ? 
[ - { - label: "Duration", - value: formatDistanceStrict(0, run.stats.duration * 1000), - }, - { label: "Steps", value: run.stats.node_exec_count }, - { label: "Cost", value: formatCredits(run.stats.cost) }, - ] - : []), - ]; - }, [run, runStatus, formatCredits]); - - const agentRunInputs: - | Record< - string, - { - title?: string; - /* type: BlockIOSubType; */ - value: string | number | undefined; - } - > - | undefined = useMemo(() => { - if (!run.inputs) return undefined; - // TODO: show (link to) preset - https://github.com/Significant-Gravitas/AutoGPT/issues/9168 - - // Add type info from agent input schema - return Object.fromEntries( - Object.entries(run.inputs).map(([k, v]) => [ - k, - { - title: graph.input_schema.properties[k]?.title, - // type: graph.input_schema.properties[k].type, // TODO: implement typed graph inputs - value: typeof v == "object" ? JSON.stringify(v, undefined, 2) : v, - }, - ]), - ); - }, [graph, run]); - - const runAgain = useCallback(() => { - if ( - !run.inputs || - !(graph.credentials_input_schema?.required ?? []).every( - (k) => k in (run.credential_inputs ?? {}), - ) - ) - return; - - if (run.preset_id) { - return api - .executeLibraryAgentPreset( - run.preset_id, - run.inputs!, - run.credential_inputs!, - ) - .then(({ id }) => { - analytics.sendDatafastEvent("run_agent", { - name: graph.name, - id: graph.id, - }); - onRun(id); - }) - .catch(toastOnFail("execute agent preset")); - } - - return api - .executeGraph( - graph.id, - graph.version, - run.inputs!, - run.credential_inputs!, - "library", - ) - .then(({ id }) => { - analytics.sendDatafastEvent("run_agent", { - name: graph.name, - id: graph.id, - }); - onRun(id); - }) - .catch(toastOnFail("execute agent")); - }, [api, graph, run, onRun, toastOnFail]); - - const stopRun = useCallback( - () => api.stopGraphExecution(graph.id, run.id), - [api, graph.id, run.id], - ); - - const agentRunOutputs: - | Record< - string, - { - title?: string; - /* type: BlockIOSubType; */ - values: Array; - } - > - | null - | undefined = useMemo(() => { - if (!("outputs" in run)) return undefined; - if (!["running", "success", "failed", "stopped"].includes(runStatus)) - return null; - - // Add type info from agent input schema - return Object.fromEntries( - Object.entries(run.outputs).map(([k, vv]) => [ - k, - { - title: graph.output_schema.properties[k].title, - /* type: agent.output_schema.properties[k].type */ - values: vv.map((v) => - typeof v == "object" ? JSON.stringify(v, undefined, 2) : v, - ), - }, - ]), - ); - }, [graph, run, runStatus]); - - const runActions: ButtonAction[] = useMemo( - () => [ - ...(["running", "queued"].includes(runStatus) - ? ([ - { - label: ( - <> - - Stop run - - ), - variant: "secondary", - callback: stopRun, - }, - ] satisfies ButtonAction[]) - : []), - ...(["success", "failed", "stopped"].includes(runStatus) && - !graph.has_external_trigger && - (graph.credentials_input_schema?.required ?? []).every( - (k) => k in (run.credential_inputs ?? {}), - ) - ? [ - { - label: ( - <> - - Run again - - ), - callback: runAgain, - dataTestId: "run-again-button", - }, - ] - : []), - ...(agent.can_access_graph - ? 
[ - { - label: "Open run in builder", - href: `/build?flowID=${run.graph_id}&flowVersion=${run.graph_version}&flowExecutionID=${run.id}`, - }, - ] - : []), - { label: "Create preset from run", callback: doCreatePresetFromRun }, - { label: "Delete run", variant: "secondary", callback: doDeleteRun }, - ], - [ - runStatus, - runAgain, - stopRun, - doDeleteRun, - doCreatePresetFromRun, - graph.has_external_trigger, - graph.credentials_input_schema?.required, - agent.can_access_graph, - run.graph_id, - run.graph_version, - run.id, - ], - ); - - return ( -
-
- - - Info - - - -
- {infoStats.map(({ label, value }) => ( -
-

{label}

-

{value}

-
- ))} -
- {run.status === "FAILED" && ( -
-

- Error:{" "} - {run.stats?.error || - "The execution failed due to an internal error. You can re-run the agent to retry."} -

-
- )} -
-
- - {/* Smart Agent Execution Summary */} - {run.stats?.activity_status && ( - - - - Task Summary - - - - - - -

- This AI-generated summary describes how the agent - handled your task. It’s an experimental feature and may - occasionally be inaccurate. -

-
-
-
-
-
- -

- {run.stats.activity_status} -

- - {/* Correctness Score */} - {typeof run.stats.correctness_score === "number" && ( -
-
- - Success Estimate: - -
-
-
= 0.8 - ? "bg-green-500" - : run.stats.correctness_score >= 0.6 - ? "bg-yellow-500" - : run.stats.correctness_score >= 0.4 - ? "bg-orange-500" - : "bg-red-500" - }`} - style={{ - width: `${Math.round(run.stats.correctness_score * 100)}%`, - }} - /> -
- - {Math.round(run.stats.correctness_score * 100)}% - -
-
- - - - - - -

- AI-generated estimate of how well this execution - achieved its intended purpose. This score indicates - {run.stats.correctness_score >= 0.8 - ? " the agent was highly successful." - : run.stats.correctness_score >= 0.6 - ? " the agent was mostly successful with minor issues." - : run.stats.correctness_score >= 0.4 - ? " the agent was partially successful with some gaps." - : " the agent had limited success with significant issues."} -

-
-
-
-
- )} - - - )} - - {agentRunOutputs !== null && ( - - )} - - {/* Pending Reviews Section */} - {runStatus === "review" && ( - - - - Pending Reviews ({pendingReviews.length}) - - - - {reviewsLoading ? ( - - ) : pendingReviews.length > 0 ? ( - - ) : ( -
- No pending reviews for this execution -
- )} -
-
- )} - - - - Input - - - {agentRunInputs !== undefined ? ( - Object.entries(agentRunInputs).map(([key, { title, value }]) => ( -
- - -
- )) - ) : ( - - )} -
-
-
- - {/* Run / Agent Actions */} - -
- ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-output-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-output-view.tsx deleted file mode 100644 index 668ac2e215..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-output-view.tsx +++ /dev/null @@ -1,178 +0,0 @@ -"use client"; - -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; -import React, { useMemo } from "react"; - -import { - Card, - CardContent, - CardHeader, - CardTitle, -} from "@/components/__legacy__/ui/card"; - -import LoadingBox from "@/components/__legacy__/ui/loading"; -import type { OutputMetadata } from "../../../../../../../../components/contextual/OutputRenderers"; -import { - globalRegistry, - OutputActions, - OutputItem, -} from "../../../../../../../../components/contextual/OutputRenderers"; - -export function AgentRunOutputView({ - agentRunOutputs, -}: { - agentRunOutputs: - | Record< - string, - { - title?: string; - /* type: BlockIOSubType; */ - values: Array; - } - > - | undefined; -}) { - const enableEnhancedOutputHandling = useGetFlag( - Flag.ENABLE_ENHANCED_OUTPUT_HANDLING, - ); - - // Prepare items for the renderer system - const outputItems = useMemo(() => { - if (!agentRunOutputs) return []; - - const items: Array<{ - key: string; - label: string; - value: unknown; - metadata?: OutputMetadata; - renderer: any; - }> = []; - - Object.entries(agentRunOutputs).forEach(([key, { title, values }]) => { - values.forEach((value, index) => { - // Enhanced metadata extraction - const metadata: OutputMetadata = {}; - - // Type guard to safely access properties - if ( - typeof value === "object" && - value !== null && - !React.isValidElement(value) - ) { - const objValue = value as any; - if (objValue.type) metadata.type = objValue.type; - if (objValue.mimeType) metadata.mimeType = objValue.mimeType; - if (objValue.filename) metadata.filename = objValue.filename; - } - - const renderer = globalRegistry.getRenderer(value, metadata); - if (renderer) { - items.push({ - key: `${key}-${index}`, - label: index === 0 ? title || key : "", - value, - metadata, - renderer, - }); - } else { - const textRenderer = globalRegistry - .getAllRenderers() - .find((r) => r.name === "TextRenderer"); - if (textRenderer) { - items.push({ - key: `${key}-${index}`, - label: index === 0 ? title || key : "", - value: JSON.stringify(value, null, 2), - metadata, - renderer: textRenderer, - }); - } - } - }); - }); - - return items; - }, [agentRunOutputs]); - - return ( - <> - {enableEnhancedOutputHandling ? ( - - -
- Output - {outputItems.length > 0 && ( - ({ - value: item.value, - metadata: item.metadata, - renderer: item.renderer, - }))} - /> - )} -
-
- - - {agentRunOutputs !== undefined ? ( - outputItems.length > 0 ? ( - outputItems.map((item) => ( - - )) - ) : ( -

- No outputs to display -

- ) - ) : ( - - )} -
-
- ) : ( - - - Output - - - - {agentRunOutputs !== undefined ? ( - Object.entries(agentRunOutputs).map( - ([key, { title, values }]) => ( -
- - {values.map((value, i) => ( -

- {value} -

- ))} - {/* TODO: pretty type-dependent rendering */} -
- ), - ) - ) : ( - - )} -
-
- )} - - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-status-chip.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-status-chip.tsx deleted file mode 100644 index 58f1ee8381..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-status-chip.tsx +++ /dev/null @@ -1,68 +0,0 @@ -import React from "react"; - -import { Badge } from "@/components/__legacy__/ui/badge"; - -import { GraphExecutionMeta } from "@/lib/autogpt-server-api/types"; - -export type AgentRunStatus = - | "success" - | "failed" - | "queued" - | "running" - | "stopped" - | "scheduled" - | "draft" - | "review"; - -export const agentRunStatusMap: Record< - GraphExecutionMeta["status"], - AgentRunStatus -> = { - INCOMPLETE: "draft", - COMPLETED: "success", - FAILED: "failed", - QUEUED: "queued", - RUNNING: "running", - TERMINATED: "stopped", - REVIEW: "review", -}; - -const statusData: Record< - AgentRunStatus, - { label: string; variant: keyof typeof statusStyles } -> = { - success: { label: "Success", variant: "success" }, - running: { label: "Running", variant: "info" }, - failed: { label: "Failed", variant: "destructive" }, - queued: { label: "Queued", variant: "warning" }, - draft: { label: "Draft", variant: "secondary" }, - stopped: { label: "Stopped", variant: "secondary" }, - scheduled: { label: "Scheduled", variant: "secondary" }, - review: { label: "In Review", variant: "warning" }, -}; - -const statusStyles = { - success: - "bg-green-100 text-green-800 hover:bg-green-100 hover:text-green-800", - destructive: "bg-red-100 text-red-800 hover:bg-red-100 hover:text-red-800", - warning: - "bg-yellow-100 text-yellow-800 hover:bg-yellow-100 hover:text-yellow-800", - info: "bg-blue-100 text-blue-800 hover:bg-blue-100 hover:text-blue-800", - secondary: - "bg-slate-100 text-slate-800 hover:bg-slate-100 hover:text-slate-800", -}; - -export function AgentRunStatusChip({ - status, -}: { - status: AgentRunStatus; -}): React.ReactElement { - return ( - - {statusData[status]?.label} - - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-summary-card.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-summary-card.tsx deleted file mode 100644 index 6f7d7865bc..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-summary-card.tsx +++ /dev/null @@ -1,130 +0,0 @@ -import React from "react"; -import { formatDistanceToNow, isPast } from "date-fns"; - -import { cn } from "@/lib/utils"; - -import { Link2Icon, Link2OffIcon, MoreVertical } from "lucide-react"; -import { Card, CardContent } from "@/components/__legacy__/ui/card"; -import { Button } from "@/components/__legacy__/ui/button"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from "@/components/__legacy__/ui/dropdown-menu"; - -import { AgentStatus, AgentStatusChip } from "./agent-status-chip"; -import { AgentRunStatus, AgentRunStatusChip } from "./agent-run-status-chip"; -import { PushPinSimpleIcon } from "@phosphor-icons/react"; - -export type AgentRunSummaryProps = ( - | { - type: "run"; - status: AgentRunStatus; - } - | { - type: "preset"; - status?: undefined; - } 
- | { - type: "preset.triggered"; - status: AgentStatus; - } - | { - type: "schedule"; - status: "scheduled"; - } -) & { - title: string; - timestamp?: number | Date; - selected?: boolean; - onClick?: () => void; - // onRename: () => void; - onDelete: () => void; - onPinAsPreset?: () => void; - className?: string; -}; - -export function AgentRunSummaryCard({ - type, - status, - title, - timestamp, - selected = false, - onClick, - // onRename, - onDelete, - onPinAsPreset, - className, -}: AgentRunSummaryProps): React.ReactElement { - return ( - - - {(type == "run" || type == "schedule") && ( - - )} - {type == "preset" && ( -
- Preset -
- )} - {type == "preset.triggered" && ( -
- - -
- {status == "inactive" ? ( - - ) : ( - - )}{" "} - Trigger -
-
- )} - -
-

- {title} -

- - - - - - - {onPinAsPreset && ( - - Pin as a preset - - )} - - {/* Rename */} - - Delete - - -
- - {timestamp && ( -

- {isPast(timestamp) ? "Ran" : "Runs in"}{" "} - {formatDistanceToNow(timestamp, { addSuffix: true })} -

- )} -
-
- ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-runs-selector-list.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-runs-selector-list.tsx deleted file mode 100644 index 49d93b4319..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-runs-selector-list.tsx +++ /dev/null @@ -1,237 +0,0 @@ -"use client"; -import { Plus } from "lucide-react"; -import React, { useEffect, useState } from "react"; - -import { - GraphExecutionID, - GraphExecutionMeta, - LibraryAgent, - LibraryAgentPreset, - LibraryAgentPresetID, - Schedule, - ScheduleID, -} from "@/lib/autogpt-server-api"; -import { cn } from "@/lib/utils"; - -import { Badge } from "@/components/__legacy__/ui/badge"; -import { Button } from "@/components/atoms/Button/Button"; -import LoadingBox, { LoadingSpinner } from "@/components/__legacy__/ui/loading"; -import { Separator } from "@/components/__legacy__/ui/separator"; -import { ScrollArea } from "@/components/__legacy__/ui/scroll-area"; -import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll"; -import { AgentRunsQuery } from "../use-agent-runs"; -import { agentRunStatusMap } from "./agent-run-status-chip"; -import { AgentRunSummaryCard } from "./agent-run-summary-card"; - -interface AgentRunsSelectorListProps { - agent: LibraryAgent; - agentRunsQuery: AgentRunsQuery; - agentPresets: LibraryAgentPreset[]; - schedules: Schedule[]; - selectedView: { type: "run" | "preset" | "schedule"; id?: string }; - allowDraftNewRun?: boolean; - onSelectRun: (id: GraphExecutionID) => void; - onSelectPreset: (preset: LibraryAgentPresetID) => void; - onSelectSchedule: (id: ScheduleID) => void; - onSelectDraftNewRun: () => void; - doDeleteRun: (id: GraphExecutionMeta) => void; - doDeletePreset: (id: LibraryAgentPresetID) => void; - doDeleteSchedule: (id: ScheduleID) => void; - doCreatePresetFromRun?: (id: GraphExecutionID) => void; - className?: string; -} - -export function AgentRunsSelectorList({ - agent, - agentRunsQuery: { - agentRuns, - agentRunCount, - agentRunsLoading, - hasMoreRuns, - fetchMoreRuns, - isFetchingMoreRuns, - }, - agentPresets, - schedules, - selectedView, - allowDraftNewRun = true, - onSelectRun, - onSelectPreset, - onSelectSchedule, - onSelectDraftNewRun, - doDeleteRun, - doDeletePreset, - doDeleteSchedule, - doCreatePresetFromRun, - className, -}: AgentRunsSelectorListProps): React.ReactElement { - const [activeListTab, setActiveListTab] = useState<"runs" | "scheduled">( - "runs", - ); - - useEffect(() => { - if (selectedView.type === "schedule") { - setActiveListTab("scheduled"); - } else { - setActiveListTab("runs"); - } - }, [selectedView]); - - const listItemClasses = "h-28 w-72 lg:w-full lg:h-32"; - - return ( - - ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx deleted file mode 100644 index 30b0a82e65..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx +++ /dev/null @@ -1,180 +0,0 @@ -"use client"; -import React, { useCallback, useMemo } from "react"; - -import { - Graph, 
- GraphExecutionID, - Schedule, - ScheduleID, -} from "@/lib/autogpt-server-api"; -import { useBackendAPI } from "@/lib/autogpt-server-api/context"; - -import ActionButtonGroup from "@/components/__legacy__/action-button-group"; -import type { ButtonAction } from "@/components/__legacy__/types"; -import { - Card, - CardContent, - CardHeader, - CardTitle, -} from "@/components/__legacy__/ui/card"; -import { IconCross } from "@/components/__legacy__/ui/icons"; -import { Input } from "@/components/__legacy__/ui/input"; -import LoadingBox from "@/components/__legacy__/ui/loading"; -import { useToastOnFail } from "@/components/molecules/Toast/use-toast"; -import { humanizeCronExpression } from "@/lib/cron-expression-utils"; -import { formatScheduleTime } from "@/lib/timezone-utils"; -import { useUserTimezone } from "@/lib/hooks/useUserTimezone"; -import { PlayIcon } from "lucide-react"; - -import { AgentRunStatus } from "./agent-run-status-chip"; - -export function AgentScheduleDetailsView({ - graph, - schedule, - agentActions, - onForcedRun, - doDeleteSchedule, -}: { - graph: Graph; - schedule: Schedule; - agentActions: ButtonAction[]; - onForcedRun: (runID: GraphExecutionID) => void; - doDeleteSchedule: (scheduleID: ScheduleID) => void; -}): React.ReactNode { - const api = useBackendAPI(); - - const selectedRunStatus: AgentRunStatus = "scheduled"; - - const toastOnFail = useToastOnFail(); - - // Get user's timezone for displaying schedule times - const userTimezone = useUserTimezone(); - - const infoStats: { label: string; value: React.ReactNode }[] = useMemo(() => { - return [ - { - label: "Status", - value: - selectedRunStatus.charAt(0).toUpperCase() + - selectedRunStatus.slice(1), - }, - { - label: "Schedule", - value: humanizeCronExpression(schedule.cron), - }, - { - label: "Next run", - value: formatScheduleTime(schedule.next_run_time, userTimezone), - }, - ]; - }, [schedule, selectedRunStatus, userTimezone]); - - const agentRunInputs: Record< - string, - { title?: string; /* type: BlockIOSubType; */ value: any } - > = useMemo(() => { - // TODO: show (link to) preset - https://github.com/Significant-Gravitas/AutoGPT/issues/9168 - - // Add type info from agent input schema - return Object.fromEntries( - Object.entries(schedule.input_data).map(([k, v]) => [ - k, - { - title: graph.input_schema.properties[k].title, - /* TODO: type: agent.input_schema.properties[k].type */ - value: v, - }, - ]), - ); - }, [graph, schedule]); - - const runNow = useCallback( - () => - api - .executeGraph( - graph.id, - graph.version, - schedule.input_data, - schedule.input_credentials, - "library", - ) - .then((run) => onForcedRun(run.id)) - .catch(toastOnFail("execute agent")), - [api, graph, schedule, onForcedRun, toastOnFail], - ); - - const runActions: ButtonAction[] = useMemo( - () => [ - { - label: ( - <> - - Run now - - ), - callback: runNow, - }, - { - label: ( - <> - - Delete schedule - - ), - callback: () => doDeleteSchedule(schedule.id), - variant: "destructive", - }, - ], - [runNow], - ); - - return ( -
-
- - - Info - - - -
- {infoStats.map(({ label, value }) => ( -
-

{label}

-

{value}

-
- ))} -
-
-
- - - - Input - - - {agentRunInputs !== undefined ? ( - Object.entries(agentRunInputs).map(([key, { title, value }]) => ( -
- - -
- )) - ) : ( - - )} -
-
-
- - {/* Run / Agent Actions */} - -
- ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/create-preset-dialog.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/create-preset-dialog.tsx deleted file mode 100644 index 2ca64d5ec5..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/create-preset-dialog.tsx +++ /dev/null @@ -1,100 +0,0 @@ -"use client"; - -import React, { useState } from "react"; -import { Button } from "@/components/__legacy__/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, -} from "@/components/__legacy__/ui/dialog"; -import { Input } from "@/components/__legacy__/ui/input"; -import { Textarea } from "@/components/__legacy__/ui/textarea"; - -interface CreatePresetDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - onConfirm: (name: string, description: string) => Promise | void; -} - -export function CreatePresetDialog({ - open, - onOpenChange, - onConfirm, -}: CreatePresetDialogProps) { - const [name, setName] = useState(""); - const [description, setDescription] = useState(""); - - const handleSubmit = async () => { - if (name.trim()) { - await onConfirm(name.trim(), description.trim()); - setName(""); - setDescription(""); - onOpenChange(false); - } - }; - - const handleCancel = () => { - setName(""); - setDescription(""); - onOpenChange(false); - }; - - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) { - e.preventDefault(); - handleSubmit(); - } - }; - - return ( - - - - Create Preset - - Give your preset a name and description to help identify it later. - - -
-
- - setName(e.target.value)} - onKeyDown={handleKeyDown} - autoFocus - /> -
-
- -