#!/usr/bin/env python3
"""
PR Overlap Detection Tool

Detects potential merge conflicts between a given PR and other open PRs
by checking for file overlap, line overlap, and actual merge conflicts.
"""

import json
import os
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import Optional


# =============================================================================
# MAIN ENTRY POINT
# =============================================================================

def main() -> None:
    """Main entry point for PR overlap detection.

    Parses CLI arguments, fetches the target PR, compares it against all
    other open PRs on the same base branch, and (unless --dry-run) posts a
    summary comment and an optional Discord notification.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Detect PR overlaps and potential merge conflicts")
    parser.add_argument("pr_number", type=int, help="PR number to check")
    parser.add_argument("--base", default=None, help="Base branch (default: auto-detect from PR)")
    parser.add_argument("--skip-merge-test", action="store_true", help="Skip actual merge conflict testing")
    # Webhook defaults to the env var so CI can configure it without flags.
    parser.add_argument("--discord-webhook", default=os.environ.get("DISCORD_WEBHOOK_URL"), help="Discord webhook URL for notifications")
    parser.add_argument("--dry-run", action="store_true", help="Don't post comments, just print")

    args = parser.parse_args()

    owner, repo = get_repo_info()
    print(f"Checking PR #{args.pr_number} in {owner}/{repo}")

    # Get current PR info
    current_pr = fetch_pr_details(args.pr_number)
    # Explicit --base wins; otherwise use the PR's own base branch.
    base_branch = args.base or current_pr.base_ref

    print(f"PR #{current_pr.number}: {current_pr.title}")
    print(f"Base branch: {base_branch}")
    print(f"Files changed: {len(current_pr.files)}")

    # Find overlapping PRs
    overlaps, all_changes = find_overlapping_prs(
        owner, repo, base_branch, current_pr, args.pr_number, args.skip_merge_test
    )

    if not overlaps:
        # NOTE(review): if a previous run posted a comment and overlaps have
        # since disappeared, the stale comment is not removed here — confirm
        # whether post_or_update_comment (defined later in this file) handles that.
        print("No overlaps detected!")
        return

    # Generate and post report
    comment = format_comment(overlaps, args.pr_number, current_pr.changed_ranges, all_changes)

    if args.dry_run:
        print("\n" + "="*60)
        print("COMMENT PREVIEW:")
        print("="*60)
        print(comment)
    else:
        if comment:
            post_or_update_comment(args.pr_number, comment)
            print("Posted comment to PR")

        if args.discord_webhook:
            send_discord_notification(args.discord_webhook, current_pr, overlaps)

    # Report results and exit
    report_results(overlaps)


# =============================================================================
# HIGH-LEVEL WORKFLOW FUNCTIONS
# =============================================================================

def fetch_pr_details(pr_number: int) -> "PullRequest":
    """Fetch details for a specific PR including its diff.

    Uses `gh pr view` for metadata, then fetches and parses the PR's unified
    diff to populate per-file changed line ranges.
    """
    result = run_gh(["pr", "view", str(pr_number), "--json", "number,title,url,author,headRefName,baseRefName,files"])
    data = json.loads(result.stdout)

    pr = PullRequest(
        number=data["number"],
        title=data["title"],
        # Author can be null for e.g. deleted accounts; fall back to "unknown".
        author=data["author"]["login"] if data.get("author") else "unknown",
        url=data["url"],
        head_ref=data["headRefName"],
        base_ref=data["baseRefName"],
        files=[f["path"] for f in data["files"]],
        changed_ranges={}
    )

    # Get detailed diff
    diff = get_pr_diff(pr_number)
    pr.changed_ranges = parse_diff_ranges(diff)

    return pr


def find_overlapping_prs(
    owner: str,
    repo: str,
    base_branch: str,
    current_pr: "PullRequest",
    current_pr_number: int,
    skip_merge_test: bool
) -> tuple[list["Overlap"], dict[int, dict[str, "ChangedFile"]]]:
    """Find all PRs that overlap with the current PR.

    Returns (overlaps, all_changes) where all_changes maps each overlapping
    PR number to its parsed per-file change ranges.
    """
    # Query other open PRs
    all_prs = query_open_prs(owner, repo, base_branch)
    other_prs = [p for p in all_prs if p["number"] != current_pr_number]

    print(f"Found {len(other_prs)} other open PRs targeting {base_branch}")

    # Find file overlaps (excluding ignored files, filtering by
    # age)
    candidates = find_file_overlap_candidates(current_pr.files, other_prs)

    print(f"Found {len(candidates)} PRs with file overlap (excluding ignored files)")

    if not candidates:
        return [], {}

    # First pass: analyze line overlaps (no merge testing yet)
    overlaps = []
    all_changes = {}
    prs_needing_merge_test = []

    for pr_data, shared_files in candidates:
        overlap, pr_changes = analyze_pr_overlap(
            owner, repo, base_branch, current_pr, pr_data, shared_files,
            skip_merge_test=True  # Always skip in first pass
        )
        if overlap:
            overlaps.append(overlap)
            all_changes[pr_data["number"]] = pr_changes
            # Track PRs that need merge testing
            if overlap.line_overlaps and not skip_merge_test:
                prs_needing_merge_test.append(overlap)

    # Second pass: batch merge testing with shared clone
    # (one clone reused for every candidate, instead of one clone per PR).
    if prs_needing_merge_test:
        run_batch_merge_tests(owner, repo, base_branch, current_pr, prs_needing_merge_test)

    return overlaps, all_changes


def run_batch_merge_tests(
    owner: str,
    repo: str,
    base_branch: str,
    current_pr: "PullRequest",
    overlaps: list["Overlap"]
) -> None:
    """Run merge tests for multiple PRs using a shared clone.

    Mutates each Overlap in place (has_merge_conflict, conflict_files,
    conflict_details, conflict_type). For every candidate PR the sequence is:
    reset the work tree to origin/<base>, merge the current PR, then merge
    the candidate on top; a failure of either merge marks a conflict.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Clone once
        # NOTE(review): clone_repo uses --depth=50; a shallow clone can make
        # git fail to find a merge base for long-lived branches — confirm the
        # depth is sufficient for this repo's PR turnaround.
        if not clone_repo(owner, repo, base_branch, tmpdir):
            return

        configure_git(tmpdir)

        # Fetch current PR branch once
        result = run_git(["fetch", "origin", f"pull/{current_pr.number}/head:pr-{current_pr.number}"], cwd=tmpdir, check=False)
        if result.returncode != 0:
            print(f"Warning: Could not fetch current PR #{current_pr.number}", file=sys.stderr)
            return

        for overlap in overlaps:
            # Each Overlap pairs the current PR with one other PR; pick the other side.
            other_pr = overlap.pr_b if overlap.pr_a.number == current_pr.number else overlap.pr_a
            print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True)

            # Clean up any in-progress merge from previous iteration
            run_git(["merge", "--abort"], cwd=tmpdir, check=False)

            # Reset to base branch
            run_git(["checkout", base_branch], cwd=tmpdir, check=False)
            run_git(["reset", "--hard", f"origin/{base_branch}"], cwd=tmpdir, check=False)
            run_git(["clean", "-fdx"], cwd=tmpdir, check=False)

            # Fetch the other PR branch
            result = run_git(["fetch", "origin", f"pull/{other_pr.number}/head:pr-{other_pr.number}"], cwd=tmpdir, check=False)
            if result.returncode != 0:
                print(f"Warning: Could not fetch PR #{other_pr.number}: {result.stderr.strip()}", file=sys.stderr)
                continue

            # Try merging current PR first
            result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{current_pr.number}"], cwd=tmpdir, check=False)
            if result.returncode != 0:
                # Current PR conflicts with base
                conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr)
                overlap.has_merge_conflict = True
                overlap.conflict_files = conflict_files
                overlap.conflict_details = conflict_details
                overlap.conflict_type = 'pr_a_conflicts_base'
                run_git(["merge", "--abort"], cwd=tmpdir, check=False)
                continue

            # Commit and try merging other PR
            run_git(["commit", "-m", f"Merge PR #{current_pr.number}"], cwd=tmpdir, check=False)

            result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{other_pr.number}"], cwd=tmpdir, check=False)
            if result.returncode != 0:
                # Conflict between PRs
                conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr)
                overlap.has_merge_conflict = True
                overlap.conflict_files = conflict_files
                overlap.conflict_details = conflict_details
                overlap.conflict_type = 'conflict'
                run_git(["merge", "--abort"], cwd=tmpdir, check=False)


def analyze_pr_overlap(
    owner: str,
    repo: str,
    base_branch: str,
    current_pr: "PullRequest",
    other_pr_data: dict,
    shared_files: list[str],
    skip_merge_test: bool
) -> tuple[Optional["Overlap"], dict[str, "ChangedFile"]]:
    """Analyze overlap between current PR and another PR.

    Returns (Overlap, other PR's changed ranges), or (None, {}) when every
    shared file is on the ignore list.
    """
    # Filter out ignored files
    non_ignored_shared = [f for f in shared_files if not should_ignore_file(f)]
    if not non_ignored_shared:
        return None, {}

    # NOTE(review): other_pr_data uses snake_case keys (head_ref, base_ref),
    # unlike the raw GraphQL node fields — presumably normalized by
    # query_open_prs before this point; confirm against that function.
    other_pr = PullRequest(
        number=other_pr_data["number"],
        title=other_pr_data["title"],
        author=other_pr_data["author"],
        url=other_pr_data["url"],
        head_ref=other_pr_data["head_ref"],
        base_ref=other_pr_data["base_ref"],
        files=other_pr_data["files"],
        changed_ranges={},
        updated_at=other_pr_data.get("updated_at")
    )

    # Get diff for other PR
    other_diff = get_pr_diff(other_pr.number)
    other_pr.changed_ranges = parse_diff_ranges(other_diff)

    # Check line overlaps
    # (shared_files is passed unfiltered; find_line_overlaps skips ignored
    # files itself, so this matches non_ignored_shared in effect.)
    line_overlaps = find_line_overlaps(
        current_pr.changed_ranges,
        other_pr.changed_ranges,
        shared_files
    )

    overlap = Overlap(
        pr_a=current_pr,
        pr_b=other_pr,
        overlapping_files=non_ignored_shared,
        line_overlaps=line_overlaps
    )

    # Test for actual merge conflicts if we have line overlaps
    if line_overlaps and not skip_merge_test:
        print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True)
        has_conflict, conflict_files, conflict_details, error_type = test_merge_conflict(
            owner, repo, base_branch, current_pr, other_pr
        )
        overlap.has_merge_conflict = has_conflict
        overlap.conflict_files = conflict_files
        overlap.conflict_details = conflict_details
        overlap.conflict_type = error_type

    return overlap, other_pr.changed_ranges


def find_file_overlap_candidates(
    current_files: list[str],
    other_prs: list[dict],
    max_age_days: int = 14
) -> list[tuple[dict, list[str]]]:
    """Find PRs that share files with the current PR.

    PRs not updated within max_age_days are skipped. Ignored files (lock
    files etc., per should_ignore_file) are excluded from the intersection.
    Returns (pr_data, shared_file_paths) pairs.
    """
    from datetime import datetime, timezone, timedelta

    current_files_set = set(f for f in current_files if not should_ignore_file(f))
    candidates = []
    cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days)

    for pr_data in other_prs:
        # Filter out PRs older than max_age_days
        updated_at = pr_data.get("updated_at")
        if updated_at:
            try:
                # GitHub emits ISO-8601 with a trailing Z; fromisoformat needs +00:00.
                pr_date = datetime.fromisoformat(updated_at.replace('Z', '+00:00'))
                if pr_date < cutoff_date:
                    continue  # Skip old PRs
            except Exception as e:
                # If we can't parse date, include the PR (safe fallback)
                # — deliberately broad catch: a bad timestamp must not abort the scan.
                print(f"Warning: Could not parse date for PR: {e}", file=sys.stderr)

        other_files = set(f for f in pr_data["files"] if not should_ignore_file(f))
        shared = current_files_set & other_files

        if shared:
            candidates.append((pr_data, list(shared)))

    return candidates


def report_results(overlaps: list["Overlap"]) -> None:
    """Report results (informational only, always exits 0)."""
    conflicts = [o for o in overlaps if o.has_merge_conflict]
    if conflicts:
        print(f"\n⚠️ Found {len(conflicts)} merge conflict(s)")

    line_overlap_count = len([o for o in overlaps if o.line_overlaps])
    if line_overlap_count:
        print(f"\n⚠️ Found {line_overlap_count} PR(s) with line overlap")

    print("\n✅ Done")
    # Always exit 0 - this check is informational, not a merge blocker


# =============================================================================
# COMMENT FORMATTING
# =============================================================================

def format_comment(
    overlaps: list["Overlap"],
    current_pr: int,
    changes_current: dict[str, "ChangedFile"],
    all_changes: dict[int, dict[str, "ChangedFile"]]
) -> str:
    """Format the overlap report as a PR comment.

    Sections: base-branch conflicts, then per-risk groups (conflict /
    medium / low), then a one-line summary. Returns "" when there is
    nothing to report.
    """
    if not overlaps:
        return ""

    lines = ["## 🔍 PR Overlap Detection"]
    lines.append("")
    lines.append("This check compares your PR against all other open PRs targeting the same branch to detect potential merge conflicts early.")
    lines.append("")

    # Check if current PR conflicts with base branch
    format_base_conflicts(overlaps, lines)

    # Classify and sort overlaps
    classified = classify_all_overlaps(overlaps, current_pr, changes_current, all_changes)

    # Group by risk
    conflicts = [(o, r) for o, r in classified if r == 'conflict']
    medium_risk = [(o, r) for o, r in classified if r == 'medium']
    low_risk = [(o, r) for o, r in classified if r == 'low']

    # Format each section
    format_conflicts_section(conflicts, current_pr,
lines) + format_medium_risk_section(medium_risk, current_pr, changes_current, all_changes, lines) + format_low_risk_section(low_risk, current_pr, lines) + + # Summary + total = len(overlaps) + lines.append(f"\n**Summary:** {len(conflicts)} conflict(s), {len(medium_risk)} medium risk, {len(low_risk)} low risk (out of {total} PRs with file overlap)") + lines.append("\n---\n*Auto-generated on push. Ignores: `openapi.json`, lock files.*") + + return "\n".join(lines) + + +def format_base_conflicts(overlaps: list["Overlap"], lines: list[str]): + """Format base branch conflicts section.""" + base_conflicts = [o for o in overlaps if o.conflict_type == 'pr_a_conflicts_base'] + if base_conflicts: + lines.append("### ⚠️ This PR has conflicts with the base branch\n") + lines.append("Conflicts will need to be resolved before merging:\n") + first = base_conflicts[0] + for f in first.conflict_files[:10]: + lines.append(f"- `{f}`") + if len(first.conflict_files) > 10: + lines.append(f"- ... and {len(first.conflict_files) - 10} more files") + lines.append("\n") + + +def format_conflicts_section(conflicts: list[tuple], current_pr: int, lines: list[str]): + """Format the merge conflicts section.""" + pr_conflicts = [(o, r) for o, r in conflicts if o.conflict_type != 'pr_a_conflicts_base'] + + if not pr_conflicts: + return + + lines.append("### 🔴 Merge Conflicts Detected") + lines.append("") + lines.append("The following PRs have been tested and **will have merge conflicts** if merged after this PR. 
Consider coordinating with the authors.") + lines.append("") + + for o, _ in pr_conflicts: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + format_pr_entry(other, lines) + format_conflict_details(o, lines) + lines.append("") + + +def format_medium_risk_section( + medium_risk: list[tuple], + current_pr: int, + changes_current: dict, + all_changes: dict, + lines: list[str] +): + """Format the medium risk section.""" + if not medium_risk: + return + + lines.append("### 🟡 Medium Risk — Some Line Overlap\n") + lines.append("These PRs have some overlapping changes:\n") + + for o, _ in medium_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other.number, {}) + format_pr_entry(other, lines) + + # Note if rename is involved + for file_path in o.overlapping_files: + file_a = changes_current.get(file_path) + file_b = other_changes.get(file_path) + if (file_a and file_a.is_rename) or (file_b and file_b.is_rename): + lines.append(f" - ⚠️ `{file_path}` is being renamed/moved") + break + + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") + + +def format_low_risk_section(low_risk: list[tuple], current_pr: int, lines: list[str]): + """Format the low risk section.""" + if not low_risk: + return + + lines.append("### 🟢 Low Risk — File Overlap Only\n") + lines.append("
These PRs touch the same files but different sections (click to expand)\n") + + for o, _ in low_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + format_pr_entry(other, lines) + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") # Add blank line between entries + + lines.append("
\n") + + +def format_pr_entry(pr: "PullRequest", lines: list[str]): + """Format a single PR entry line.""" + updated = format_relative_time(pr.updated_at) + updated_str = f" · updated {updated}" if updated else "" + # Just use #number - GitHub auto-renders it with title + lines.append(f"- #{pr.number} ({pr.author}{updated_str})") + + +def format_conflict_details(overlap: "Overlap", lines: list[str]): + """Format conflict details for a PR.""" + if overlap.conflict_details: + all_paths = [d.path for d in overlap.conflict_details] + common_prefix = find_common_prefix(all_paths) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for detail in overlap.conflict_details: + display_path = detail.path[len(common_prefix):] if common_prefix else detail.path + size_str = format_conflict_size(detail) + lines.append(f" - `{display_path}`{size_str}") + elif overlap.conflict_files: + common_prefix = find_common_prefix(overlap.conflict_files) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for f in overlap.conflict_files: + display_path = f[len(common_prefix):] if common_prefix else f + lines.append(f" - `{display_path}`") + + +def format_conflict_size(detail: "ConflictInfo") -> str: + """Format conflict size string for a file.""" + if detail.conflict_count > 0: + return f" ({detail.conflict_count} conflict{'s' if detail.conflict_count > 1 else ''}, ~{detail.conflict_lines} lines)" + elif detail.conflict_type != 'content': + type_labels = { + 'both_added': 'added in both', + 'both_deleted': 'deleted in both', + 'deleted_by_us': 'deleted here, modified there', + 'deleted_by_them': 'modified here, deleted there', + 'added_by_us': 'added here', + 'added_by_them': 'added there', + } + label = type_labels.get(detail.conflict_type, detail.conflict_type) + return f" ({label})" + return "" + + +def format_line_overlaps(line_overlaps: dict[str, list[tuple]], lines: list[str]): + """Format line overlap details.""" + all_paths = list(line_overlaps.keys()) + 
common_prefix = find_common_prefix(all_paths) if len(all_paths) > 1 else "" + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for file_path, ranges in line_overlaps.items(): + display_path = file_path[len(common_prefix):] if common_prefix else file_path + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + indent = " " if common_prefix else " " + lines.append(f"{indent}- `{display_path}`: {', '.join(range_strs)}") + + +# ============================================================================= +# OVERLAP ANALYSIS +# ============================================================================= + +def classify_all_overlaps( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict, + all_changes: dict +) -> list[tuple["Overlap", str]]: + """Classify all overlaps by risk level and sort them.""" + classified = [] + for o in overlaps: + other_pr = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other_pr.number, {}) + risk = classify_overlap_risk(o, changes_current, other_changes) + classified.append((o, risk)) + + def sort_key(item): + o, risk = item + risk_order = {'conflict': 0, 'medium': 1, 'low': 2} + # For conflicts, also sort by total conflict lines (descending) + conflict_lines = sum(d.conflict_lines for d in o.conflict_details) if o.conflict_details else 0 + return (risk_order.get(risk, 99), -conflict_lines) + + classified.sort(key=sort_key) + + return classified + + +def classify_overlap_risk( + overlap: "Overlap", + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"] +) -> str: + """Classify the risk level of an overlap.""" + if overlap.has_merge_conflict: + return 'conflict' + + has_rename = any( + (changes_a.get(f) and changes_a[f].is_rename) or + (changes_b.get(f) and changes_b[f].is_rename) + for f in overlap.overlapping_files + ) + + if overlap.line_overlaps: + total_overlap_lines = sum( + end - start + 1 + for ranges in 
overlap.line_overlaps.values() + for start, end in ranges + ) + + # Medium risk: >20 lines overlap or file rename + if total_overlap_lines > 20 or has_rename: + return 'medium' + else: + return 'low' + + if has_rename: + return 'medium' + + return 'low' + + +def find_line_overlaps( + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"], + shared_files: list[str] +) -> dict[str, list[tuple[int, int]]]: + """Find overlapping line ranges in shared files.""" + overlaps = {} + + for file_path in shared_files: + if should_ignore_file(file_path): + continue + + file_a = changes_a.get(file_path) + file_b = changes_b.get(file_path) + + if not file_a or not file_b: + continue + + # Skip pure renames + if file_a.is_rename and not file_a.additions and not file_a.deletions: + continue + if file_b.is_rename and not file_b.additions and not file_b.deletions: + continue + + # Note: This mixes old-file (deletions) and new-file (additions) line numbers, + # which can cause false positives when PRs insert/remove many lines. + # Acceptable for v1 since the real merge test is the authoritative check. 
+ file_overlaps = find_range_overlaps( + file_a.additions + file_a.deletions, + file_b.additions + file_b.deletions + ) + + if file_overlaps: + overlaps[file_path] = merge_ranges(file_overlaps) + + return overlaps + + +def find_range_overlaps( + ranges_a: list[tuple[int, int]], + ranges_b: list[tuple[int, int]] +) -> list[tuple[int, int]]: + """Find overlapping regions between two sets of ranges.""" + overlaps = [] + for range_a in ranges_a: + for range_b in ranges_b: + if ranges_overlap(range_a, range_b): + overlap_start = max(range_a[0], range_b[0]) + overlap_end = min(range_a[1], range_b[1]) + overlaps.append((overlap_start, overlap_end)) + return overlaps + + +def ranges_overlap(range_a: tuple[int, int], range_b: tuple[int, int]) -> bool: + """Check if two line ranges overlap.""" + return range_a[0] <= range_b[1] and range_b[0] <= range_a[1] + + +def merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]: + """Merge overlapping line ranges.""" + if not ranges: + return [] + + sorted_ranges = sorted(ranges, key=lambda x: x[0]) + merged = [sorted_ranges[0]] + + for current in sorted_ranges[1:]: + last = merged[-1] + if current[0] <= last[1] + 1: + merged[-1] = (last[0], max(last[1], current[1])) + else: + merged.append(current) + + return merged + + +# ============================================================================= +# MERGE CONFLICT TESTING +# ============================================================================= + +def test_merge_conflict( + owner: str, + repo: str, + base_branch: str, + pr_a: "PullRequest", + pr_b: "PullRequest" +) -> tuple[bool, list[str], list["ConflictInfo"], str]: + """Test if merging both PRs would cause a conflict.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone repo + if not clone_repo(owner, repo, base_branch, tmpdir): + return False, [], [], None + + configure_git(tmpdir) + if not fetch_pr_branches(tmpdir, pr_a.number, pr_b.number): + # Fetch failed for one or both PRs - can't test merge 
+ return False, [], [], None + + # Try merging PR A first + conflict_result = try_merge_pr(tmpdir, pr_a.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'pr_a_conflicts_base' + + # Commit and try merging PR B + run_git(["commit", "-m", f"Merge PR #{pr_a.number}"], cwd=tmpdir, check=False) + + conflict_result = try_merge_pr(tmpdir, pr_b.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'conflict' + + return False, [], [], None + + +def clone_repo(owner: str, repo: str, branch: str, tmpdir: str) -> bool: + """Clone the repository.""" + clone_url = f"https://github.com/{owner}/{repo}.git" + result = run_git( + ["clone", "--depth=50", "--branch", branch, clone_url, tmpdir], + check=False + ) + if result.returncode != 0: + print(f"Failed to clone: {result.stderr}", file=sys.stderr) + return False + return True + + +def configure_git(tmpdir: str): + """Configure git for commits.""" + run_git(["config", "user.email", "github-actions[bot]@users.noreply.github.com"], cwd=tmpdir, check=False) + run_git(["config", "user.name", "github-actions[bot]"], cwd=tmpdir, check=False) + + +def fetch_pr_branches(tmpdir: str, pr_a: int, pr_b: int) -> bool: + """Fetch both PR branches. Returns False if any fetch fails.""" + success = True + for pr_num in (pr_a, pr_b): + result = run_git(["fetch", "origin", f"pull/{pr_num}/head:pr-{pr_num}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{pr_num}: {result.stderr.strip()}", file=sys.stderr) + success = False + return success + + +def try_merge_pr(tmpdir: str, pr_number: int) -> Optional[tuple[list[str], list["ConflictInfo"]]]: + """Try to merge a PR. 
Returns conflict info if conflicts, None if success.""" + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{pr_number}"], cwd=tmpdir, check=False) + + if result.returncode == 0: + return None + + # Conflict detected + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + return conflict_files, conflict_details + + +def extract_conflict_info(tmpdir: str, stderr: str) -> tuple[list[str], list["ConflictInfo"]]: + """Extract conflict information from git status.""" + status_result = run_git(["status", "--porcelain"], cwd=tmpdir, check=False) + + status_types = { + 'UU': 'content', + 'AA': 'both_added', + 'DD': 'both_deleted', + 'DU': 'deleted_by_us', + 'UD': 'deleted_by_them', + 'AU': 'added_by_us', + 'UA': 'added_by_them', + } + + conflict_files = [] + conflict_details = [] + + for line in status_result.stdout.split("\n"): + if len(line) >= 3 and line[0:2] in status_types: + status_code = line[0:2] + file_path = line[3:].strip() + conflict_files.append(file_path) + + info = analyze_conflict_markers(file_path, tmpdir) + info.conflict_type = status_types.get(status_code, 'unknown') + conflict_details.append(info) + + # Fallback to stderr parsing + if not conflict_files and stderr: + for line in stderr.split("\n"): + if "CONFLICT" in line and ":" in line: + parts = line.split(":") + if len(parts) > 1: + file_part = parts[-1].strip() + if file_part and not file_part.startswith("Merge"): + conflict_files.append(file_part) + conflict_details.append(ConflictInfo(path=file_part)) + + return conflict_files, conflict_details + + +def analyze_conflict_markers(file_path: str, cwd: str) -> "ConflictInfo": + """Analyze a conflicted file to count conflict regions and lines.""" + info = ConflictInfo(path=file_path) + + try: + full_path = os.path.join(cwd, file_path) + with open(full_path, 'r', errors='ignore') as f: + content = f.read() + + in_conflict = False + current_conflict_lines = 0 

        # Count marker-delimited regions: each <<<<<<< opens a conflict,
        # the matching >>>>>>> closes it; conflict_lines includes the
        # markers and the ======= separator.
        for line in content.split('\n'):
            if line.startswith('<<<<<<<'):
                in_conflict = True
                info.conflict_count += 1
                current_conflict_lines = 1
            elif line.startswith('>>>>>>>'):
                in_conflict = False
                current_conflict_lines += 1
                info.conflict_lines += current_conflict_lines
            elif in_conflict:
                current_conflict_lines += 1
    except Exception as e:
        # Best-effort: an unreadable file still yields an (empty) ConflictInfo.
        print(f"Warning: Could not analyze conflict markers in {file_path}: {e}", file=sys.stderr)

    return info


# =============================================================================
# DIFF PARSING
# =============================================================================

def parse_diff_ranges(diff: str) -> dict[str, "ChangedFile"]:
    """Parse a unified diff and extract changed line ranges per file.

    Keyed by the new-file path (the "+++ b/" header). Rename metadata is
    carried from "rename from"/"similarity index" lines into the entry.
    NOTE(review): deleted files ("+++ /dev/null") never get an entry, so
    deletions of whole files are invisible to overlap detection — confirm
    that is intended.
    """
    files = {}
    current_file = None
    pending_rename_from = None
    is_rename = False

    for line in diff.split("\n"):
        # Reset rename state on new file diff header
        if line.startswith("diff --git "):
            is_rename = False
            pending_rename_from = None
        elif line.startswith("rename from "):
            pending_rename_from = line[12:]
            is_rename = True
        elif line.startswith("rename to "):
            pass  # rename target is captured via "+++ b/" line
        elif line.startswith("similarity index"):
            is_rename = True
        elif line.startswith("+++ b/"):
            path = line[6:]
            current_file = ChangedFile(
                path=path,
                additions=[],
                deletions=[],
                is_rename=is_rename,
                old_path=pending_rename_from
            )
            files[path] = current_file
            pending_rename_from = None
            is_rename = False
        elif line.startswith("--- /dev/null"):
            # New file: cannot be a rename.
            is_rename = False
            pending_rename_from = None
        elif line.startswith("@@") and current_file:
            parse_hunk_header(line, current_file)

    return files


def parse_hunk_header(line: str, current_file: "ChangedFile") -> None:
    """Parse a diff hunk header ("@@ -a,b +c,d @@") and record its ranges.

    Appends (start, end) inclusive tuples: the old-file span to
    current_file.deletions and the new-file span to current_file.additions.
    A missing count defaults to 1 per unified-diff convention; a count of 0
    (pure insert/delete hunks) produces no range for that side.
    """
    match = re.match(r"@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@", line)
    if match:
        old_start = int(match.group(1))
        old_count = int(match.group(2) or 1)
        new_start = int(match.group(3))
        new_count = int(match.group(4) or 1)

        if old_count > 0:
            current_file.deletions.append((old_start, old_start + old_count - 1))
        if new_count > 0:
            current_file.additions.append((new_start, new_start + new_count - 1))


# =============================================================================
# GITHUB API
# =============================================================================

def get_repo_info() -> tuple[str, str]:
    """Get owner and repo name from environment or git.

    Prefers GITHUB_REPOSITORY (set in Actions); falls back to `gh repo view`.
    """
    if os.environ.get("GITHUB_REPOSITORY"):
        owner, repo = os.environ["GITHUB_REPOSITORY"].split("/")
        return owner, repo

    result = run_gh(["repo", "view", "--json", "owner,name"])
    data = json.loads(result.stdout)
    return data["owner"]["login"], data["name"]


def query_open_prs(owner: str, repo: str, base_branch: str) -> list[dict]:
    """Query all open PRs targeting the specified base branch.

    Paginates the GraphQL pullRequests connection 100 at a time.
    """
    prs = []
    cursor = None

    while True:
        after_clause = f', after: "{cursor}"' if cursor else ""
        query = f'''
        query {{
          repository(owner: "{owner}", name: "{repo}") {{
            pullRequests(
              first: 100{after_clause},
              states: OPEN,
              baseRefName: "{base_branch}",
              orderBy: {{field: UPDATED_AT, direction: DESC}}
            ) {{
              totalCount
              edges {{
                node {{
                  number
                  title
                  url
                  updatedAt
                  author {{ login }}
                  headRefName
                  baseRefName
                  files(first: 100) {{
                    nodes {{ path }}
                    pageInfo {{ hasNextPage }}
                  }}
                }}
              }}
              pageInfo {{
                endCursor
                hasNextPage
              }}
            }}
          }}
        }}
        '''

        result = run_gh(["api", "graphql", "-f", f"query={query}"])
        data = json.loads(result.stdout)

        if "errors" in data:
            print(f"GraphQL errors: {data['errors']}", file=sys.stderr)
            sys.exit(1)

        pr_data = data["data"]["repository"]["pullRequests"]
        for edge in pr_data["edges"]:
            node = edge["node"]
            files_data = node["files"]
            # Warn if PR has
more than 100 files (API limit, we only fetch first 100) + if files_data.get("pageInfo", {}).get("hasNextPage"): + print(f"Warning: PR #{node['number']} has >100 files, overlap detection may be incomplete", file=sys.stderr) + prs.append({ + "number": node["number"], + "title": node["title"], + "url": node["url"], + "updated_at": node.get("updatedAt"), + "author": node["author"]["login"] if node["author"] else "unknown", + "head_ref": node["headRefName"], + "base_ref": node["baseRefName"], + "files": [f["path"] for f in files_data["nodes"]] + }) + + if not pr_data["pageInfo"]["hasNextPage"]: + break + cursor = pr_data["pageInfo"]["endCursor"] + + return prs + + +def get_pr_diff(pr_number: int) -> str: + """Get the diff for a PR.""" + result = run_gh(["pr", "diff", str(pr_number)]) + return result.stdout + + +def post_or_update_comment(pr_number: int, body: str): + """Post a new comment or update existing overlap detection comment.""" + if not body: + return + + marker = "## 🔍 PR Overlap Detection" + + # Find existing comment using GraphQL + owner, repo = get_repo_info() + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequest(number: {pr_number}) {{ + comments(first: 100) {{ + nodes {{ + id + body + author {{ login }} + }} + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"], check=False) + + existing_comment_id = None + if result.returncode == 0: + try: + data = json.loads(result.stdout) + comments = data.get("data", {}).get("repository", {}).get("pullRequest", {}).get("comments", {}).get("nodes", []) + for comment in comments: + if marker in comment.get("body", ""): + existing_comment_id = comment["id"] + break + except Exception as e: + print(f"Warning: Could not search for existing comment: {e}", file=sys.stderr) + + if existing_comment_id: + # Update existing comment using GraphQL mutation + # Use json.dumps for proper escaping of all special characters + escaped_body = json.dumps(body)[1:-1] 
# Strip outer quotes added by json.dumps + mutation = f''' + mutation {{ + updateIssueComment(input: {{id: "{existing_comment_id}", body: "{escaped_body}"}}) {{ + issueComment {{ id }} + }} + }} + ''' + result = run_gh(["api", "graphql", "-f", f"query={mutation}"], check=False) + if result.returncode == 0: + print(f"Updated existing overlap comment") + else: + # Fallback to posting new comment + print(f"Failed to update comment, posting new one: {result.stderr}", file=sys.stderr) + run_gh(["pr", "comment", str(pr_number), "--body", body]) + else: + # Post new comment + run_gh(["pr", "comment", str(pr_number), "--body", body]) + + +def send_discord_notification(webhook_url: str, pr: "PullRequest", overlaps: list["Overlap"]): + """Send a Discord notification about significant overlaps.""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if not conflicts: + return + + # Discord limits: max 25 fields, max 1024 chars per field value + fields = [] + for o in conflicts[:25]: + other = o.pr_b if o.pr_a.number == pr.number else o.pr_a + # Build value string with truncation to stay under 1024 chars + file_list = o.conflict_files[:3] + files_str = f"Files: `{'`, `'.join(file_list)}`" + if len(o.conflict_files) > 3: + files_str += f" (+{len(o.conflict_files) - 3} more)" + value = f"[{other.title[:100]}]({other.url})\n{files_str}" + # Truncate if still too long + if len(value) > 1024: + value = value[:1020] + "..." + fields.append({ + "name": f"Conflicts with #{other.number}", + "value": value, + "inline": False + }) + + embed = { + "title": f"⚠️ PR #{pr.number} has merge conflicts", + "description": f"[{pr.title}]({pr.url})", + "color": 0xFF0000, + "fields": fields + } + + if len(conflicts) > 25: + embed["footer"] = {"text": f"... 
and {len(conflicts) - 25} more conflicts"} + + try: + subprocess.run( + ["curl", "-X", "POST", "-H", "Content-Type: application/json", + "--max-time", "10", + "-d", json.dumps({"embeds": [embed]}), webhook_url], + capture_output=True, + timeout=15 + ) + except subprocess.TimeoutExpired: + print("Warning: Discord webhook timed out", file=sys.stderr) + + +# ============================================================================= +# UTILITIES +# ============================================================================= + +def run_gh(args: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Run a gh CLI command.""" + result = subprocess.run( + ["gh"] + args, + capture_output=True, + text=True, + check=False + ) + if check and result.returncode != 0: + print(f"Error running gh {' '.join(args)}: {result.stderr}", file=sys.stderr) + sys.exit(1) + return result + + +def run_git(args: list[str], cwd: str = None, check: bool = True) -> subprocess.CompletedProcess: + """Run a git command.""" + result = subprocess.run( + ["git"] + args, + capture_output=True, + text=True, + cwd=cwd, + check=False + ) + if check and result.returncode != 0: + print(f"Error running git {' '.join(args)}: {result.stderr}", file=sys.stderr) + return result + + +def should_ignore_file(path: str) -> bool: + """Check if a file should be ignored for overlap detection.""" + if path in IGNORE_FILES: + return True + basename = path.split("/")[-1] + return basename in IGNORE_FILES + + +def find_common_prefix(paths: list[str]) -> str: + """Find the common directory prefix of a list of file paths.""" + if not paths: + return "" + if len(paths) == 1: + parts = paths[0].rsplit('/', 1) + return parts[0] + '/' if len(parts) > 1 else "" + + split_paths = [p.split('/') for p in paths] + common = [] + for parts in zip(*split_paths): + if len(set(parts)) == 1: + common.append(parts[0]) + else: + break + + return '/'.join(common) + '/' if common else "" + + +def 
format_relative_time(iso_timestamp: str) -> str: + """Format an ISO timestamp as relative time.""" + if not iso_timestamp: + return "" + + from datetime import datetime, timezone + try: + dt = datetime.fromisoformat(iso_timestamp.replace('Z', '+00:00')) + now = datetime.now(timezone.utc) + diff = now - dt + + seconds = diff.total_seconds() + if seconds < 60: + return "just now" + elif seconds < 3600: + return f"{int(seconds / 60)}m ago" + elif seconds < 86400: + return f"{int(seconds / 3600)}h ago" + else: + return f"{int(seconds / 86400)}d ago" + except Exception as e: + print(f"Warning: Could not format relative time: {e}", file=sys.stderr) + return "" + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + +@dataclass +class ChangedFile: + """Represents a file changed in a PR.""" + path: str + additions: list[tuple[int, int]] + deletions: list[tuple[int, int]] + is_rename: bool = False + old_path: str = None + + +@dataclass +class PullRequest: + """Represents a pull request.""" + number: int + title: str + author: str + url: str + head_ref: str + base_ref: str + files: list[str] + changed_ranges: dict[str, ChangedFile] + updated_at: str = None + + +@dataclass +class ConflictInfo: + """Info about a single conflicting file.""" + path: str + conflict_count: int = 0 + conflict_lines: int = 0 + conflict_type: str = "content" + + +@dataclass +class Overlap: + """Represents an overlap between two PRs.""" + pr_a: PullRequest + pr_b: PullRequest + overlapping_files: list[str] + line_overlaps: dict[str, list[tuple[int, int]]] + has_merge_conflict: bool = False + conflict_files: list[str] = None + conflict_details: list[ConflictInfo] = None + conflict_type: str = None + + def __post_init__(self): + if self.conflict_files is None: + self.conflict_files = [] + if self.conflict_details is None: + self.conflict_details = [] + + +# 
============================================================================= +# CONSTANTS +# ============================================================================= + +IGNORE_FILES = { + "autogpt_platform/frontend/src/app/api/openapi.json", + "poetry.lock", + "pnpm-lock.yaml", + "package-lock.json", + "yarn.lock", +} + + +# ============================================================================= +# ENTRY POINT +# ============================================================================= + +if __name__ == "__main__": + main() diff --git a/.github/workflows/claude-ci-failure-auto-fix.yml b/.github/workflows/claude-ci-failure-auto-fix.yml index ab07c8ae10..dbca6dc3f3 100644 --- a/.github/workflows/claude-ci-failure-auto-fix.yml +++ b/.github/workflows/claude-ci-failure-auto-fix.yml @@ -40,6 +40,48 @@ jobs: git checkout -b "$BRANCH_NAME" echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT + # Backend Python/Poetry setup (so Claude can run linting/tests) + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Set up Python dependency cache + uses: actions/cache@v5 + with: + path: ~/.cache/pypoetry + key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} + + - name: Install Poetry + run: | + cd autogpt_platform/backend + HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry) + curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install Python dependencies + working-directory: autogpt_platform/backend + run: poetry install + + - name: Generate Prisma Client + working-directory: autogpt_platform/backend + run: poetry run prisma generate && poetry run gen-prisma-stub + + # Frontend Node.js/pnpm setup (so Claude can run linting/tests) + - name: Enable corepack + run: corepack enable + + - name: Set up Node.js + uses: actions/setup-node@v6 + 
with: + node-version: "22" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml + + - name: Install JavaScript dependencies + working-directory: autogpt_platform/frontend + run: pnpm install --frozen-lockfile + - name: Get CI failure details id: failure_details uses: actions/github-script@v8 diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index da37df6de7..274c6d2cab 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -77,27 +77,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index ee901fe5d4..8b8260af6b 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -93,27 +93,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - 
- - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 966243323c..ff535f8496 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -62,7 +62,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} @@ -93,6 +93,6 @@ jobs: exit 1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/pr-overlap-check.yml b/.github/workflows/pr-overlap-check.yml new file mode 100644 index 0000000000..c53f56321b --- /dev/null +++ b/.github/workflows/pr-overlap-check.yml @@ -0,0 +1,39 @@ +name: PR Overlap Detection + +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - dev + - master + +permissions: + contents: read + pull-requests: write + +jobs: + check-overlaps: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history for merge testing + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' 
+ + - name: Configure git + run: | + git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + + - name: Run overlap detection + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # Always succeed - this check informs contributors, it shouldn't block merging + continue-on-error: true + run: | + python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }} diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index fa52ba812a..2711bd2df9 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -104,6 +104,12 @@ TWITTER_CLIENT_SECRET= # Make a new workspace for your OAuth APP -- trust me # https://linear.app/settings/api/applications/new # Callback URL: http://localhost:3000/auth/integrations/oauth_callback +LINEAR_API_KEY= +# Linear project and team IDs for the feature request tracker. +# Find these in your Linear workspace URL: linear.app//project/ +# and in team settings. Used by the chat copilot to file and search feature requests. +LINEAR_FEATURE_REQUEST_PROJECT_ID= +LINEAR_FEATURE_REQUEST_TEAM_ID= LINEAR_CLIENT_ID= LINEAR_CLIENT_SECRET= diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index ace534b730..05a8d4858b 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -66,13 +66,19 @@ ENV POETRY_HOME=/opt/poetry \ DEBIAN_FRONTEND=noninteractive ENV PATH=/opt/poetry/bin:$PATH -# Install Python, FFmpeg, and ImageMagick (required for video processing blocks) +# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use. +# bubblewrap provides OS-level sandbox (whitelist-only FS + no network) +# for the bash_exec MCP tool. # Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc. 
RUN apt-get update && apt-get install -y --no-install-recommends \ python3.13 \ python3-pip \ ffmpeg \ imagemagick \ + jq \ + ripgrep \ + tree \ + bubblewrap \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3* diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 7d28f3a832..1bb7ff43d0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,5 +1,6 @@ """Chat API routes for chat session management and streaming via SSE.""" +import asyncio import logging import uuid as uuid_module from collections.abc import AsyncGenerator @@ -19,12 +20,14 @@ from backend.copilot.completion_handler import ( from backend.copilot.config import ChatConfig from backend.copilot.executor.utils import enqueue_copilot_task from backend.copilot.model import ( + ChatMessage, ChatSession, + append_and_save_message, create_chat_session, get_chat_session, get_user_sessions, ) -from backend.copilot.response_model import StreamFinish, StreamHeartbeat +from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat from backend.copilot.tools.models import ( AgentDetailsResponse, AgentOutputResponse, @@ -48,6 +51,7 @@ from backend.copilot.tools.models import ( SetupRequirementsResponse, UnderstandingUpdatedResponse, ) +from backend.copilot.tracking import track_user_message from backend.util.exceptions import NotFoundError config = ChatConfig() @@ -240,6 +244,10 @@ async def get_session( active_task, last_message_id = await stream_registry.get_active_task_for_session( session_id, user_id ) + logger.info( + f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, " + f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" + ) if active_task: # Filter out the in-progress assistant message from the 
session response. # The client will receive the complete assistant response through the SSE @@ -309,10 +317,9 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - - _session = await _validate_and_get_session(session_id, user_id) # noqa: F841 + await _validate_and_get_session(session_id, user_id) logger.info( - f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms", + f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -321,6 +328,25 @@ async def stream_chat_post( }, ) + # Atomically append user message to session BEFORE creating task to avoid + # race condition where GET_SESSION sees task as "running" but message isn't + # saved yet. append_and_save_message re-fetches inside a lock to prevent + # message loss from concurrent requests. + if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, + ) + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + await append_and_save_message(session_id, message) + logger.info(f"[STREAM] User message saved for session {session_id}") + # Create a task in the stream registry for reconnection support task_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4()) @@ -336,7 +362,7 @@ async def stream_chat_post( operation_id=operation_id, ) logger.info( - f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms", + f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -453,8 +479,14 @@ async def stream_chat_post( "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)} }, ) + # Surface error 
to frontend so it doesn't appear stuck + yield StreamError( + errorText="An error occurred. Please try again.", + code="stream_error", + ).to_sse() + yield StreamFinish().to_sse() finally: - # Unsubscribe when client disconnects or stream ends to prevent resource leak + # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: await stream_registry.unsubscribe_from_task( @@ -698,8 +730,6 @@ async def stream_task( ) async def event_generator() -> AsyncGenerator[str, None]: - import asyncio - heartbeat_interval = 15.0 # Send heartbeat every 15 seconds try: while True: diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index f129d2707b..5fdcfb6d82 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -126,6 +126,7 @@ class PrintToConsoleBlock(Block): output_schema=PrintToConsoleBlock.Output, test_input={"text": "Hello, World!"}, is_sensitive_action=True, + disabled=True, # Disabled per Nick Tindle's request (OPEN-3000) test_output=[ ("output", "Hello, World!"), ("status", "printed"), diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index 808692f97f..04bbe8e60d 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -27,12 +27,11 @@ class ChatConfig(BaseSettings): session_ttl: int = Field(default=43200, description="Session TTL in seconds") # Streaming Configuration - max_context_messages: int = Field( - default=50, ge=1, le=200, description="Maximum context messages" - ) - stream_timeout: int = Field(default=300, description="Stream timeout in seconds") - max_retries: int = Field(default=3, description="Maximum number of retries") + max_retries: int = Field( + default=3, + description="Max retries for fallback path (SDK handles retries internally)", + ) max_agent_runs: int = 
Field(default=30, description="Maximum number of agent runs") max_agent_schedules: int = Field( default=30, description="Maximum number of agent schedules" @@ -93,6 +92,31 @@ class ChatConfig(BaseSettings): description="Name of the prompt in Langfuse to fetch", ) + # Claude Agent SDK Configuration + use_claude_agent_sdk: bool = Field( + default=True, + description="Use Claude Agent SDK for chat completions", + ) + claude_agent_model: str | None = Field( + default=None, + description="Model for the Claude Agent SDK path. If None, derives from " + "the `model` field by stripping the OpenRouter provider prefix.", + ) + claude_agent_max_buffer_size: int = Field( + default=10 * 1024 * 1024, # 10MB (default SDK is 1MB) + description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. " + "Increase if tool outputs exceed the limit.", + ) + claude_agent_max_subtasks: int = Field( + default=10, + description="Max number of sub-agent Tasks the SDK can spawn per session.", + ) + claude_agent_use_resume: bool = Field( + default=True, + description="Use --resume for multi-turn conversations instead of " + "history compression. 
Falls back to compression when unavailable.", + ) + # Extended thinking configuration for Claude models thinking_enabled: bool = Field( default=True, @@ -138,6 +162,17 @@ class ChatConfig(BaseSettings): v = os.getenv("CHAT_INTERNAL_API_KEY") return v + @field_validator("use_claude_agent_sdk", mode="before") + @classmethod + def get_use_claude_agent_sdk(cls, v): + """Get use_claude_agent_sdk from environment if not provided.""" + # Check environment variable - default to True if not set + env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower() + if env_val: + return env_val in ("true", "1", "yes", "on") + # Default to True (SDK enabled by default) + return True if v is None else v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/copilot/model.py b/autogpt_platform/backend/backend/copilot/model.py index c9500337eb..b48e471a21 100644 --- a/autogpt_platform/backend/backend/copilot/model.py +++ b/autogpt_platform/backend/backend/copilot/model.py @@ -360,7 +360,7 @@ async def get_chat_session( logger.warning(f"Unexpected cache error for session {session_id}: {e}") # Fall back to database - logger.info(f"Session {session_id} not in cache, checking database") + logger.debug(f"Session {session_id} not in cache, checking database") session = await _get_session_from_db(session_id) if session is None: @@ -540,6 +540,40 @@ async def _save_session_to_db( ) +async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession: + """Atomically append a message to a session and persist it. + + Acquires the session lock, re-fetches the latest session state, + appends the message, and saves — preventing message loss when + concurrent requests modify the same session. 
+ """ + lock = await _get_session_lock(session_id) + + async with lock: + session = await get_chat_session(session_id) + if session is None: + raise ValueError(f"Session {session_id} not found") + + session.messages.append(message) + existing_message_count = await chat_db().get_chat_session_message_count( + session_id + ) + + try: + await _save_session_to_db(session, existing_message_count) + except Exception as e: + raise DatabaseError( + f"Failed to persist message to session {session_id}" + ) from e + + try: + await cache_chat_session(session) + except Exception as e: + logger.warning(f"Cache write failed for session {session_id}: {e}") + + return session + + async def create_chat_session(user_id: str) -> ChatSession: """Create a new chat session and persist it. @@ -647,13 +681,19 @@ async def update_session_title(session_id: str, title: str) -> bool: logger.warning(f"Session {session_id} not found for title update") return False - # Invalidate cache so next fetch gets updated title + # Update title in cache if it exists (instead of invalidating). + # This prevents race conditions where cache invalidation causes + # the frontend to see stale DB data while streaming is still in progress. 
try: - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - await async_redis.delete(redis_key) + cached = await _get_session_from_cache(session_id) + if cached: + cached.title = title + await cache_chat_session(cached) except Exception as e: - logger.warning(f"Failed to invalidate cache for session {session_id}: {e}") + # Not critical - title will be correct on next full cache refresh + logger.warning( + f"Failed to update title in cache for session {session_id}: {e}" + ) return True except Exception as e: diff --git a/autogpt_platform/backend/backend/copilot/sdk/__init__.py b/autogpt_platform/backend/backend/copilot/sdk/__init__.py new file mode 100644 index 0000000000..7d9d6371e9 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/__init__.py @@ -0,0 +1,14 @@ +"""Claude Agent SDK integration for CoPilot. + +This module provides the integration layer between the Claude Agent SDK +and the existing CoPilot tool system, enabling drop-in replacement of +the current LLM orchestration with the battle-tested Claude Agent SDK. +""" + +from .service import stream_chat_completion_sdk +from .tool_adapter import create_copilot_mcp_server + +__all__ = [ + "stream_chat_completion_sdk", + "create_copilot_mcp_server", +] diff --git a/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py new file mode 100644 index 0000000000..7a3976ae42 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py @@ -0,0 +1,201 @@ +"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + +This module provides the adapter layer that converts streaming messages from +the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that +the frontend expects. 
+""" + +import json +import logging +import uuid + +from claude_agent_sdk import ( + AssistantMessage, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.copilot.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output + +logger = logging.getLogger(__name__) + + +class SDKResponseAdapter: + """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + + This class maintains state during a streaming session to properly track + text blocks, tool calls, and message lifecycle. + """ + + def __init__(self, message_id: str | None = None): + self.message_id = message_id or str(uuid.uuid4()) + self.text_block_id = str(uuid.uuid4()) + self.has_started_text = False + self.has_ended_text = False + self.current_tool_calls: dict[str, dict[str, str]] = {} + self.task_id: str | None = None + self.step_open = False + + def set_task_id(self, task_id: str) -> None: + """Set the task ID for reconnection support.""" + self.task_id = task_id + + def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format.""" + responses: list[StreamBaseResponse] = [] + + if isinstance(sdk_message, SystemMessage): + if sdk_message.subtype == "init": + responses.append( + StreamStart(messageId=self.message_id, taskId=self.task_id) + ) + # Open the first step (matches non-SDK: StreamStart then StreamStartStep) + responses.append(StreamStartStep()) + self.step_open = True + + elif isinstance(sdk_message, AssistantMessage): + # After tool results, the SDK sends a new AssistantMessage for the + # next LLM turn. Open a new step if the previous one was closed. 
+ if not self.step_open: + responses.append(StreamStartStep()) + self.step_open = True + + for block in sdk_message.content: + if isinstance(block, TextBlock): + if block.text: + self._ensure_text_started(responses) + responses.append( + StreamTextDelta(id=self.text_block_id, delta=block.text) + ) + + elif isinstance(block, ToolUseBlock): + self._end_text_if_open(responses) + + # Strip MCP prefix so frontend sees "find_block" + # instead of "mcp__copilot__find_block". + tool_name = block.name.removeprefix(MCP_TOOL_PREFIX) + + responses.append( + StreamToolInputStart(toolCallId=block.id, toolName=tool_name) + ) + responses.append( + StreamToolInputAvailable( + toolCallId=block.id, + toolName=tool_name, + input=block.input, + ) + ) + self.current_tool_calls[block.id] = {"name": tool_name} + + elif isinstance(sdk_message, UserMessage): + # UserMessage carries tool results back from tool execution. + content = sdk_message.content + blocks = content if isinstance(content, list) else [] + for block in blocks: + if isinstance(block, ToolResultBlock) and block.tool_use_id: + tool_info = self.current_tool_calls.get(block.tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + + # Prefer the stashed full output over the SDK's + # (potentially truncated) ToolResultBlock content. + # The SDK truncates large results, writing them to disk, + # which breaks frontend widget parsing. + output = pop_pending_tool_output(tool_name) or ( + _extract_tool_output(block.content) + ) + + responses.append( + StreamToolOutputAvailable( + toolCallId=block.tool_use_id, + toolName=tool_name, + output=output, + success=not (block.is_error or False), + ) + ) + + # Close the current step after tool results — the next + # AssistantMessage will open a new step for the continuation. + if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + elif isinstance(sdk_message, ResultMessage): + self._end_text_if_open(responses) + # Close the step before finishing. 
+ if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + if sdk_message.subtype == "success": + responses.append(StreamFinish()) + elif sdk_message.subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "result", None) or "Unknown error" + responses.append( + StreamError(errorText=str(error_msg), code="sdk_error") + ) + responses.append(StreamFinish()) + else: + logger.warning( + f"Unexpected ResultMessage subtype: {sdk_message.subtype}" + ) + responses.append(StreamFinish()) + + else: + logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}") + + return responses + + def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None: + """Start (or restart) a text block if needed.""" + if not self.has_started_text or self.has_ended_text: + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None: + """End the current text block if one is open.""" + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + +def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str: + """Extract a string output from a ToolResultBlock's content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + if parts: + return "".join(parts) + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) + if content is None: + return "" + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) diff --git a/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py 
b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py new file mode 100644 index 0000000000..7555eb8046 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py @@ -0,0 +1,366 @@ +"""Unit tests for the SDK response adapter.""" + +from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.copilot.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .response_adapter import SDKResponseAdapter +from .tool_adapter import MCP_TOOL_PREFIX + + +def _adapter() -> SDKResponseAdapter: + a = SDKResponseAdapter(message_id="msg-1") + a.set_task_id("task-1") + return a + + +# -- SystemMessage ----------------------------------------------------------- + + +def test_system_init_emits_start_and_step(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="init", data={})) + assert len(results) == 2 + assert isinstance(results[0], StreamStart) + assert results[0].messageId == "msg-1" + assert results[0].taskId == "task-1" + assert isinstance(results[1], StreamStartStep) + + +def test_system_non_init_emits_nothing(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="other", data={})) + assert results == [] + + +# -- AssistantMessage with TextBlock ----------------------------------------- + + +def test_text_block_emits_step_start_and_delta(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="hello")], model="test") + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], 
StreamTextDelta) + assert results[2].delta == "hello" + + +def test_empty_text_block_emits_only_step(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="")], model="test") + results = adapter.convert_message(msg) + # Empty text skipped, but step still opens + assert len(results) == 1 + assert isinstance(results[0], StreamStartStep) + + +def test_multiple_text_deltas_reuse_block_id(): + adapter = _adapter() + msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test") + msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test") + r1 = adapter.convert_message(msg1) + r2 = adapter.convert_message(msg2) + # First gets step+start+delta, second only delta (block & step already started) + assert len(r1) == 3 + assert isinstance(r1[0], StreamStartStep) + assert isinstance(r1[1], StreamTextStart) + assert len(r2) == 1 + assert isinstance(r2[0], StreamTextDelta) + assert r1[1].id == r2[0].id # same block ID + + +# -- AssistantMessage with ToolUseBlock -------------------------------------- + + +def test_tool_use_emits_input_start_and_available(): + """Tool names arrive with MCP prefix and should be stripped for the frontend.""" + adapter = _adapter() + msg = AssistantMessage( + content=[ + ToolUseBlock( + id="tool-1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"q": "x"}, + ) + ], + model="test", + ) + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamToolInputStart) + assert results[1].toolCallId == "tool-1" + assert results[1].toolName == "find_agent" # prefix stripped + assert isinstance(results[2], StreamToolInputAvailable) + assert results[2].toolName == "find_agent" # prefix stripped + assert results[2].input == {"q": "x"} + + +def test_text_then_tool_ends_text_block(): + adapter = _adapter() + text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") + tool_msg = AssistantMessage( + 
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + adapter.convert_message(text_msg) # opens step + text + results = adapter.convert_message(tool_msg) + # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamToolInputStart) + + +# -- UserMessage with ToolResultBlock ---------------------------------------- + + +def test_tool_result_emits_output_and_finish_step(): + adapter = _adapter() + # First register the tool call (opens step) — SDK sends prefixed name + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})], + model="test", + ) + adapter.convert_message(tool_msg) + + # Now send tool result + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")] + ) + results = adapter.convert_message(result_msg) + assert len(results) == 2 + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].toolCallId == "t1" + assert results[0].toolName == "find_agent" # prefix stripped + assert results[0].output == "found 3 agents" + assert results[0].success is True + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_error(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={}) + ], + model="test", + ) + ) + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].success is False + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_list_content(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", 
input={})], + model="test", + ) + ) + result_msg = UserMessage( + content=[ + ToolResultBlock( + tool_use_id="t1", + content=[ + {"type": "text", "text": "line1"}, + {"type": "text", "text": "line2"}, + ], + ) + ] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].output == "line1line2" + assert isinstance(results[1], StreamFinishStep) + + +def test_string_user_message_ignored(): + """A plain string UserMessage (not tool results) produces no output.""" + adapter = _adapter() + results = adapter.convert_message(UserMessage(content="hello")) + assert results == [] + + +# -- ResultMessage ----------------------------------------------------------- + + +def test_result_success_emits_finish_step_and_finish(): + adapter = _adapter() + # Start some text first (opens step) + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="done")], model="test") + ) + msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="s1", + ) + results = adapter.convert_message(msg) + # TextEnd + FinishStep + StreamFinish + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamFinishStep) + assert isinstance(results[2], StreamFinish) + + +def test_result_error_emits_error_and_finish(): + adapter = _adapter() + msg = ResultMessage( + subtype="error", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=0, + session_id="s1", + result="API rate limited", + ) + results = adapter.convert_message(msg) + # No step was open, so no FinishStep — just Error + Finish + assert len(results) == 2 + assert isinstance(results[0], StreamError) + assert "API rate limited" in results[0].errorText + assert isinstance(results[1], StreamFinish) + + +# -- Text after tools (new block ID) ---------------------------------------- + + +def test_text_after_tool_gets_new_block_id(): + 
adapter = _adapter() + # Text -> Tool -> ToolResult -> Text should get a new text block ID and step + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="before")], model="test") + ) + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + # Send tool result (closes step) + adapter.convert_message( + UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")]) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="after")], model="test") + ) + # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "after" + + +# -- Full conversation flow -------------------------------------------------- + + +def test_full_conversation_flow(): + """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish.""" + adapter = _adapter() + all_responses: list[StreamBaseResponse] = [] + + # 1. Init + all_responses.extend( + adapter.convert_message(SystemMessage(subtype="init", data={})) + ) + # 2. Assistant text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="Let me search")], model="test") + ) + ) + # 3. Tool use + all_responses.extend( + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock( + id="t1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"query": "email"}, + ) + ], + model="test", + ) + ) + ) + # 4. Tool result + all_responses.extend( + adapter.convert_message( + UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")] + ) + ) + ) + # 5. 
More text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="I found 2")], model="test") + ) + ) + # 6. Result + all_responses.extend( + adapter.convert_message( + ResultMessage( + subtype="success", + duration_ms=500, + duration_api_ms=400, + is_error=False, + num_turns=2, + session_id="s1", + ) + ) + ) + + types = [type(r).__name__ for r in all_responses] + assert types == [ + "StreamStart", + "StreamStartStep", # step 1: text + tool call + "StreamTextStart", + "StreamTextDelta", # "Let me search" + "StreamTextEnd", # closed before tool + "StreamToolInputStart", + "StreamToolInputAvailable", + "StreamToolOutputAvailable", # tool result + "StreamFinishStep", # step 1 closed after tool result + "StreamStartStep", # step 2: continuation text + "StreamTextStart", # new block after tool + "StreamTextDelta", # "I found 2" + "StreamTextEnd", # closed by result + "StreamFinishStep", # step 2 closed + "StreamFinish", + ] diff --git a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py new file mode 100644 index 0000000000..5224400f96 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py @@ -0,0 +1,335 @@ +"""Security hooks for Claude Agent SDK integration. + +This module provides security hooks that validate tool calls before execution, +ensuring multi-user isolation and preventing unauthorized operations. +""" + +import json +import logging +import os +import re +from collections.abc import Callable +from typing import Any, cast + +from .tool_adapter import MCP_TOOL_PREFIX + +logger = logging.getLogger(__name__) + +# Tools that are blocked entirely (CLI/system access). +# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked +# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). 
BLOCKED_TOOLS = {
    "Bash",
    "bash",
    "shell",
    "exec",
    "terminal",
    "command",
}

# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}

# Dangerous patterns in tool inputs. Matched case-insensitively against the
# JSON-serialized tool input (see _validate_tool_access).
# NOTE(review): plain-substring patterns like "subprocess" or "/etc/passwd"
# can false-positive on benign text arguments — confirm this is acceptable
# for the tools that reach this check.
DANGEROUS_PATTERNS = [
    r"sudo",
    r"rm\s+-rf",
    r"dd\s+if=",
    r"/etc/passwd",
    r"/etc/shadow",
    r"chmod\s+777",
    r"curl\s+.*\|.*sh",
    r"wget\s+.*\|.*sh",
    r"eval\s*\(",
    r"exec\s*\(",
    r"__import__",
    r"os\.system",
    r"subprocess",
]


def _deny(reason: str) -> dict[str, Any]:
    """Return a PreToolUse hook denial response carrying *reason*."""
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "deny",
            "permissionDecisionReason": reason,
        }
    }


def _validate_workspace_path(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
    """Validate that a workspace-scoped tool only accesses allowed paths.

    Allowed directories:
    - The SDK working directory (``/tmp/copilot-<...>/``)
    - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)

    Returns:
        Empty dict to allow, or a denial dict (see ``_deny``).
    """
    path = tool_input.get("file_path") or tool_input.get("path") or ""
    if not path:
        # Glob/Grep without a path default to cwd which is already sandboxed
        return {}

    # Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
    # naturally uses relative paths like "test.txt" instead of absolute ones).
    # Tilde paths (~/) are home-dir references, not relative — expand first.
    # realpath resolves symlinks and normalizes ".." segments, so traversal
    # tricks are flattened before any prefix comparison below.
    if path.startswith("~"):
        resolved = os.path.realpath(os.path.expanduser(path))
    elif not os.path.isabs(path) and sdk_cwd:
        resolved = os.path.realpath(os.path.join(sdk_cwd, path))
    else:
        resolved = os.path.realpath(path)

    # Allow access within the SDK working directory. The trailing os.sep
    # prevents sibling-prefix matches (e.g. "/tmp/copilot-abc-evil" passing
    # a check against "/tmp/copilot-abc").
    if sdk_cwd:
        norm_cwd = os.path.realpath(sdk_cwd)
        if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
            return {}

    # Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
    claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
    tool_results_seg = os.sep + "tool-results" + os.sep
    if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
        return {}

    # Everything else is denied; log both the raw and the resolved path so
    # traversal attempts are visible in the audit trail.
    logger.warning(
        f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
    )
    workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
    return _deny(
        f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
        f"directory.{workspace_hint} "
        "This is enforced by the platform and cannot be bypassed."
    )


def _validate_tool_access(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
    """Validate that a tool call is allowed.

    Checks, in order: the hard blocklist, workspace confinement for file
    tools, and dangerous-pattern scanning of the serialized input.

    Returns:
        Empty dict to allow, or dict with hookSpecificOutput to deny
    """
    # Block forbidden tools
    if tool_name in BLOCKED_TOOLS:
        logger.warning(f"Blocked tool access attempt: {tool_name}")
        return _deny(
            f"[SECURITY] Tool '{tool_name}' is blocked for security. "
            "This is enforced by the platform and cannot be bypassed. "
            "Use the CoPilot-specific MCP tools instead."
+ ) + + # Workspace-scoped tools: allowed only within the SDK workspace directory + if tool_name in WORKSPACE_SCOPED_TOOLS: + return _validate_workspace_path(tool_name, tool_input, sdk_cwd) + + # Check for dangerous patterns in tool input + # Use json.dumps for predictable format (str() produces Python repr) + input_str = json.dumps(tool_input) if tool_input else "" + + for pattern in DANGEROUS_PATTERNS: + if re.search(pattern, input_str, re.IGNORECASE): + logger.warning( + f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}" + ) + return _deny( + "[SECURITY] Input contains a blocked pattern. " + "This is enforced by the platform and cannot be bypassed." + ) + + return {} + + +def _validate_user_isolation( + tool_name: str, tool_input: dict[str, Any], user_id: str | None +) -> dict[str, Any]: + """Validate that tool calls respect user isolation.""" + # For workspace file tools, ensure path doesn't escape + if "workspace" in tool_name.lower(): + path = tool_input.get("path", "") or tool_input.get("file_path", "") + if path: + # Check for path traversal + if ".." in path or path.startswith("/"): + logger.warning( + f"Blocked path traversal attempt: {path} by user {user_id}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Path traversal not allowed", + } + } + + return {} + + +def create_security_hooks( + user_id: str | None, + sdk_cwd: str | None = None, + max_subtasks: int = 3, + on_stop: Callable[[str, str], None] | None = None, +) -> dict[str, Any]: + """Create the security hooks configuration for Claude Agent SDK. 
+ + Includes security validation and observability hooks: + - PreToolUse: Security validation before tool execution + - PostToolUse: Log successful tool executions + - PostToolUseFailure: Log and handle failed tool executions + - PreCompact: Log context compaction events (SDK handles compaction automatically) + - Stop: Capture transcript path for stateless resume (when *on_stop* is provided) + + Args: + user_id: Current user ID for isolation validation + sdk_cwd: SDK working directory for workspace-scoped tool validation + max_subtasks: Maximum Task (sub-agent) spawns allowed per session + on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when + the SDK finishes processing — used to read the JSONL transcript + before the CLI process exits. + + Returns: + Hooks configuration dict for ClaudeAgentOptions + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + # Per-session counter for Task sub-agent spawns + task_spawn_count = 0 + + async def pre_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Combined pre-tool-use validation hook.""" + nonlocal task_spawn_count + _ = context # unused but required by signature + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Rate-limit Task (sub-agent) spawns per session + if tool_name == "Task": + task_spawn_count += 1 + if task_spawn_count > max_subtasks: + logger.warning( + f"[SDK] Task limit reached ({max_subtasks}), user={user_id}" + ) + return cast( + SyncHookJSONOutput, + _deny( + f"Maximum {max_subtasks} sub-tasks per session. " + "Please continue in the main conversation." 
+ ), + ) + + # Strip MCP prefix for consistent validation + is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX) + clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX) + + # Only block non-CoPilot tools; our MCP-registered tools + # (including Read for oversized results) are already sandboxed. + if not is_copilot_tool: + result = _validate_tool_access(clean_name, tool_input, sdk_cwd) + if result: + return cast(SyncHookJSONOutput, result) + + # Validate user isolation + result = _validate_user_isolation(clean_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log successful tool executions for observability.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_failure_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log failed tool executions for debugging.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + error = input_data.get("error", "Unknown error") + logger.warning( + f"[SDK] Tool failed: {tool_name}, error={error}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + return cast(SyncHookJSONOutput, {}) + + async def pre_compact_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log when SDK triggers context compaction. + + The SDK automatically compacts conversation history when it grows too large. + This hook provides visibility into when compaction happens. 
+ """ + _ = context, tool_use_id + trigger = input_data.get("trigger", "auto") + logger.info( + f"[SDK] Context compaction triggered: {trigger}, user={user_id}" + ) + return cast(SyncHookJSONOutput, {}) + + # --- Stop hook: capture transcript path for stateless resume --- + async def stop_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Capture transcript path when SDK finishes processing. + + The Stop hook fires while the CLI process is still alive, giving us + a reliable window to read the JSONL transcript before SIGTERM. + """ + _ = context, tool_use_id + transcript_path = cast(str, input_data.get("transcript_path", "")) + sdk_session_id = cast(str, input_data.get("session_id", "")) + + if transcript_path and on_stop: + logger.info( + f"[SDK] Stop hook: transcript_path={transcript_path}, " + f"sdk_session_id={sdk_session_id[:12]}..." + ) + on_stop(transcript_path, sdk_session_id) + + return cast(SyncHookJSONOutput, {}) + + hooks: dict[str, Any] = { + "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])], + "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])], + "PostToolUseFailure": [ + HookMatcher(matcher="*", hooks=[post_tool_failure_hook]) + ], + "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])], + } + + if on_stop is not None: + hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])] + + return hooks + except ImportError: + # Fallback for when SDK isn't available - return empty hooks + logger.warning("claude-agent-sdk not available, security hooks disabled") + return {} diff --git a/autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py new file mode 100644 index 0000000000..e1891cf1bd --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py @@ -0,0 +1,190 @@ +"""Tests for SDK security hooks — workspace paths, tool access, and 
deny messages. + +These are pure unit tests with no external dependencies (no SDK, no DB, no server). +They validate that the security hooks correctly block unauthorized paths, +tool access, and dangerous input patterns. +""" + +import os + +from .security_hooks import _validate_tool_access, _validate_user_isolation + +SDK_CWD = "/tmp/copilot-abc123" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +def _reason(result: dict) -> str: + return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") + + +# -- Blocked tools ----------------------------------------------------------- + + +def test_blocked_tools_denied(): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"{tool} should be blocked" + + +def test_unknown_tool_allowed(): + result = _validate_tool_access("SomeCustomTool", {}) + assert result == {} + + +# -- Workspace-scoped tools -------------------------------------------------- + + +def test_read_within_workspace_allowed(): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_write_within_workspace_allowed(): + result = _validate_tool_access( + "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_edit_within_workspace_allowed(): + result = _validate_tool_access( + "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_glob_within_workspace_allowed(): + result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_grep_within_workspace_allowed(): + result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_outside_workspace_denied(): + result = 
_validate_tool_access( + "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_write_outside_workspace_denied(): + result = _validate_tool_access( + "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_traversal_attack_denied(): + result = _validate_tool_access( + "Read", + {"file_path": f"{SDK_CWD}/../../etc/passwd"}, + sdk_cwd=SDK_CWD, + ) + assert _is_denied(result) + + +def test_no_path_allowed(): + """Glob/Grep without a path argument defaults to cwd — should pass.""" + result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_no_cwd_denies_absolute(): + """If no sdk_cwd is set, absolute paths are denied.""" + result = _validate_tool_access("Read", {"file_path": "/tmp/anything"}) + assert _is_denied(result) + + +# -- Tool-results directory -------------------------------------------------- + + +def test_read_tool_results_allowed(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_claude_projects_without_tool_results_denied(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Built-in Bash is blocked (use bash_exec MCP tool instead) --------------- + + +def test_bash_builtin_always_blocked(): + """SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Dangerous patterns ------------------------------------------------------ + + +def test_dangerous_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"cmd": 
"sudo rm -rf /"}) + assert _is_denied(result) + + +def test_subprocess_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"}) + assert _is_denied(result) + + +# -- User isolation ---------------------------------------------------------- + + +def test_workspace_path_traversal_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_absolute_path_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "/etc/passwd"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_normal_path_allowed(): + result = _validate_user_isolation( + "workspace_read", {"path": "src/main.py"}, user_id="user-1" + ) + assert result == {} + + +def test_non_workspace_tool_passes_isolation(): + result = _validate_user_isolation( + "find_agent", {"query": "email"}, user_id="user-1" + ) + assert result == {} + + +# -- Deny message quality ---------------------------------------------------- + + +def test_blocked_tool_message_clarity(): + """Deny messages must include [SECURITY] and 'cannot be bypassed'.""" + reason = _reason(_validate_tool_access("bash", {})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + +def test_bash_builtin_blocked_message_clarity(): + reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py new file mode 100644 index 0000000000..9c1b05198e --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -0,0 +1,751 @@ +"""Claude Agent SDK service layer for CoPilot chat completions.""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import AsyncGenerator +from dataclasses import 
dataclass +from typing import Any + +from backend.util.exceptions import NotFoundError + +from .. import stream_registry +from ..config import ChatConfig +from ..model import ( + ChatMessage, + ChatSession, + get_chat_session, + update_session_title, + upsert_chat_session, +) +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamToolInputAvailable, + StreamToolOutputAvailable, +) +from ..service import ( + _build_system_prompt, + _execute_long_running_tool_with_streaming, + _generate_session_title, +) +from ..tools.models import OperationPendingResponse, OperationStartedResponse +from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path +from ..tracking import track_user_message +from .response_adapter import SDKResponseAdapter +from .security_hooks import create_security_hooks +from .tool_adapter import ( + COPILOT_TOOL_NAMES, + LongRunningCallback, + create_copilot_mcp_server, + set_execution_context, +) +from .transcript import ( + download_transcript, + read_transcript_file, + upload_transcript, + validate_transcript, + write_transcript_to_tempfile, +) + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Set to hold background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[Any]] = set() + + +@dataclass +class CapturedTranscript: + """Info captured by the SDK Stop hook for stateless --resume.""" + + path: str = "" + sdk_session_id: str = "" + + @property + def available(self) -> bool: + return bool(self.path) + + +_SDK_CWD_PREFIX = WORKSPACE_PREFIX + +# Appended to the system prompt to inform the agent about available tools. +# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). +_SDK_TOOL_SUPPLEMENT = """ + +## Tool notes + +- The SDK built-in Bash tool is NOT available. 
Use the `bash_exec` MCP tool + for shell commands — it runs in a network-isolated sandbox. +- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the + same working directory. Files created by one are readable by the other. + These files are **ephemeral** — they exist only for the current session. +- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file` + for files that should persist across sessions (stored in cloud storage). +- Long-running tools (create_agent, edit_agent, etc.) are handled + asynchronously. You will receive an immediate response; the actual result + is delivered to the user via a background stream. +""" + + +def _build_long_running_callback(user_id: str | None) -> LongRunningCallback: + """Build a callback that delegates long-running tools to the non-SDK infrastructure. + + Long-running tools (create_agent, edit_agent, etc.) are delegated to the + existing background infrastructure: stream_registry (Redis Streams), + database persistence, and SSE reconnection. This means results survive + page refreshes / pod restarts, and the frontend shows the proper loading + widget with progress updates. + + The returned callback matches the ``LongRunningCallback`` signature: + ``(tool_name, args, session) -> MCP response dict``. + """ + + async def _callback( + tool_name: str, args: dict[str, Any], session: ChatSession + ) -> dict[str, Any]: + operation_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}" + session_id = session.session_id + + # --- Build user-friendly messages (matches non-SDK service) --- + if tool_name == "create_agent": + desc = args.get("description", "") + desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc + pending_msg = ( + f"Creating your agent: {desc_preview}" + if desc_preview + else "Creating agent... This may take a few minutes." + ) + started_msg = ( + "Agent creation started. 
You can close this tab - " + "check your library in a few minutes." + ) + elif tool_name == "edit_agent": + changes = args.get("changes", "") + changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes + pending_msg = ( + f"Editing agent: {changes_preview}" + if changes_preview + else "Editing agent... This may take a few minutes." + ) + started_msg = ( + "Agent edit started. You can close this tab - " + "check your library in a few minutes." + ) + else: + pending_msg = f"Running {tool_name}... This may take a few minutes." + started_msg = ( + f"{tool_name} started. You can close this tab - " + "check back in a few minutes." + ) + + # --- Register task in Redis for SSE reconnection --- + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # --- Save OperationPendingResponse to chat history --- + pending_message = ChatMessage( + role="tool", + content=OperationPendingResponse( + message=pending_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + tool_call_id=tool_call_id, + ) + session.messages.append(pending_message) + await upsert_chat_session(session) + + # --- Spawn background task (reuses non-SDK infrastructure) --- + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( + tool_name=tool_name, + parameters=args, + tool_call_id=tool_call_id, + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + user_id=user_id, + ) + ) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + logger.info( + f"[SDK] Long-running tool {tool_name} delegated to background " + f"(operation_id={operation_id}, task_id={task_id})" + ) + + # --- Return OperationStartedResponse as MCP tool result --- + # This flows through SDK → response adapter → frontend, triggering + 
def _resolve_sdk_model() -> str | None:
    """Resolve the model name to pass to the Claude Agent SDK CLI.

    Prefers the explicit ``config.claude_agent_model`` override; otherwise
    derives the name from ``config.model`` by dropping an OpenRouter-style
    provider prefix (``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
    """
    override = config.claude_agent_model
    if override:
        return override
    # "provider/model" -> "model"; a bare name is returned unchanged.
    _, sep, suffix = config.model.partition("/")
    return suffix if sep else config.model


def _build_sdk_env() -> dict[str, str]:
    """Build environment variables for the SDK CLI process.

    Routes the CLI's API traffic through OpenRouter (or a custom
    ``config.base_url``) using the same credentials as the non-SDK path,
    preserving per-call token and cost tracking on the OpenRouter dashboard.

    Returns an empty dict — leaving the SDK on its default credentials —
    unless both ``config.api_key`` and a valid ``config.base_url`` are set.
    """
    if not (config.api_key and config.base_url):
        return {}

    # The SDK expects the base URL without a trailing version path.
    base = config.base_url.rstrip("/")
    base = base.removesuffix("/v1")
    if not base.startswith("http"):
        # Invalid base_url — don't override SDK defaults.
        return {}

    return {
        "ANTHROPIC_BASE_URL": base,
        "ANTHROPIC_AUTH_TOKEN": config.api_key,
        # Must be explicitly empty so the CLI uses AUTH_TOKEN instead.
        "ANTHROPIC_API_KEY": "",
    }
def _cleanup_sdk_tool_results(cwd: str) -> None:
    """Remove SDK tool-result files for a specific session working directory.

    The SDK CLI writes oversized tool results under
    ``~/.claude/projects/<encoded-cwd>/tool-results/``. Only the files for
    *this* session's cwd are removed, so concurrent sessions never race on
    each other's results. The session cwd itself is removed as well.

    Security: *cwd* MUST come from ``_make_sdk_cwd()``, which sanitizes the
    session id; the prefix checks below are defence-in-depth.

    Args:
        cwd: Session working directory created by ``_make_sdk_cwd``.
    """
    import shutil

    # Defence-in-depth: refuse anything outside the sandbox prefix.
    normalized = os.path.normpath(cwd)
    if not normalized.startswith(_SDK_CWD_PREFIX):
        logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
        return

    # The SDK encodes a cwd path by replacing '/' with '-'.
    encoded_cwd = normalized.replace("/", "-")

    # Known-safe home expansion; re-check containment after joining anyway.
    claude_projects = os.path.expanduser("~/.claude/projects")
    project_dir = os.path.normpath(os.path.join(claude_projects, encoded_cwd))
    if not project_dir.startswith(claude_projects):
        logger.warning(
            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
        )
        return

    # Best-effort removal of this session's tool-result files only.
    results_dir = os.path.join(project_dir, "tool-results")
    if os.path.isdir(results_dir):
        for filename in os.listdir(results_dir):
            file_path = os.path.join(results_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError:
                pass  # best-effort: a vanished/locked file is not fatal

    # Remove the session cwd. ignore_errors=True already suppresses OSError,
    # so the previous try/except around this call was dead code.
    shutil.rmtree(normalized, ignore_errors=True)
async def _compress_conversation_history(
    session: ChatSession,
) -> list[ChatMessage]:
    """Compress prior conversation messages if they exceed the token threshold.

    Uses the shared compress_context() from prompt.py which supports:
    - LLM summarization of old messages (keeps recent ones intact)
    - Progressive content truncation as fallback
    - Middle-out deletion as last resort

    Returns the compressed prior messages (everything except the current message).
    """
    # The last entry is the in-flight message — excluded here; the caller
    # re-attaches it after compression.
    prior = session.messages[:-1]
    if len(prior) < 2:
        # Nothing worth compressing with fewer than two prior messages.
        return prior

    from backend.util.prompt import compress_context

    # Convert ChatMessages to dicts for compress_context (empty fields are
    # omitted so only real content is considered).
    messages_dict = []
    for msg in prior:
        msg_dict: dict[str, Any] = {"role": msg.role}
        if msg.content:
            msg_dict["content"] = msg.content
        if msg.tool_calls:
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)

    try:
        import openai

        # NOTE(review): presumably compress_context only invokes the client
        # when summarization is actually needed — confirm in prompt.py.
        async with openai.AsyncOpenAI(
            api_key=config.api_key, base_url=config.base_url, timeout=30.0
        ) as client:
            result = await compress_context(
                messages=messages_dict,
                model=config.model,
                client=client,
            )
    except Exception as e:
        logger.warning(f"[SDK] Context compression with LLM failed: {e}")
        # Fall back to truncation-only (no LLM summarization)
        result = await compress_context(
            messages=messages_dict,
            model=config.model,
            client=None,
        )

    if result.was_compacted:
        logger.info(
            f"[SDK] Context compacted: {result.original_token_count} -> "
            f"{result.token_count} tokens "
            f"({result.messages_summarized} summarized, "
            f"{result.messages_dropped} dropped)"
        )
        # Convert compressed dicts back to ChatMessages
        return [
            ChatMessage(
                role=m["role"],
                content=m.get("content"),
                tool_calls=m.get("tool_calls"),
                tool_call_id=m.get("tool_call_id"),
            )
            for m in result.messages
        ]

    # Under threshold — return the prior messages untouched.
    return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
    """Render prior messages as a plain-text context prefix.

    Only ``user`` and ``assistant`` messages with non-empty content are
    included, one per line (``User: ...`` / ``You responded: ...``); tool
    messages are internal details and are dropped.

    Returns ``None`` when nothing remains to show.
    """
    prefixes = {"user": "User: ", "assistant": "You responded: "}
    rendered = [
        f"{prefixes[msg.role]}{msg.content}"
        for msg in messages
        if msg.content and msg.role in prefixes
    ]
    if not rendered:
        return None
    return "\n{}\n".format("\n".join(rendered))
+ ) + + if message: + session.messages.append( + ChatMessage( + role="user" if is_user_message else "assistant", content=message + ) + ) + if is_user_message: + track_user_message( + user_id=user_id, session_id=session_id, message_length=len(message) + ) + + session = await upsert_chat_session(session) + + # Generate title for new sessions (first user message) + if is_user_message and not session.title: + user_messages = [m for m in session.messages if m.role == "user"] + if len(user_messages) == 1: + first_message = user_messages[0].content or message or "" + if first_message: + task = asyncio.create_task( + _update_title_async(session_id, first_message, user_id) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + # Build system prompt (reuses non-SDK path with Langfuse support) + has_history = len(session.messages) > 1 + system_prompt, _ = await _build_system_prompt( + user_id, has_conversation_history=has_history + ) + system_prompt += _SDK_TOOL_SUPPLEMENT + message_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + + yield StreamStart(messageId=message_id, taskId=task_id) + + stream_completed = False + # Initialise sdk_cwd before the try so the finally can reference it + # even if _make_sdk_cwd raises (in that case it stays as ""). + sdk_cwd = "" + use_resume = False + + try: + # Use a session-specific temp dir to avoid cleanup race conditions + # between concurrent sessions. + sdk_cwd = _make_sdk_cwd(session_id) + os.makedirs(sdk_cwd, exist_ok=True) + + set_execution_context( + user_id, + session, + long_running_callback=_build_long_running_callback(user_id), + ) + try: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + # Fail fast when no API credentials are available at all + sdk_env = _build_sdk_env() + if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"): + raise RuntimeError( + "No API key configured. 
Set OPEN_ROUTER_API_KEY " + "(or CHAT_API_KEY) for OpenRouter routing, " + "or ANTHROPIC_API_KEY for direct Anthropic access." + ) + + mcp_server = create_copilot_mcp_server() + + sdk_model = _resolve_sdk_model() + + # --- Transcript capture via Stop hook --- + captured_transcript = CapturedTranscript() + + def _on_stop(transcript_path: str, sdk_session_id: str) -> None: + captured_transcript.path = transcript_path + captured_transcript.sdk_session_id = sdk_session_id + + security_hooks = create_security_hooks( + user_id, + sdk_cwd=sdk_cwd, + max_subtasks=config.claude_agent_max_subtasks, + on_stop=_on_stop if config.claude_agent_use_resume else None, + ) + + # --- Resume strategy: download transcript from bucket --- + resume_file: str | None = None + use_resume = False + + if config.claude_agent_use_resume and user_id and len(session.messages) > 1: + transcript_content = await download_transcript(user_id, session_id) + if transcript_content and validate_transcript(transcript_content): + resume_file = write_transcript_to_tempfile( + transcript_content, session_id, sdk_cwd + ) + if resume_file: + use_resume = True + logger.info( + f"[SDK] Using --resume with transcript " + f"({len(transcript_content)} bytes)" + ) + + sdk_options_kwargs: dict[str, Any] = { + "system_prompt": system_prompt, + "mcp_servers": {"copilot": mcp_server}, + "allowed_tools": COPILOT_TOOL_NAMES, + "disallowed_tools": ["Bash"], + "hooks": security_hooks, + "cwd": sdk_cwd, + "max_buffer_size": config.claude_agent_max_buffer_size, + } + if sdk_env: + sdk_options_kwargs["model"] = sdk_model + sdk_options_kwargs["env"] = sdk_env + if use_resume and resume_file: + sdk_options_kwargs["resume"] = resume_file + + options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] + + adapter = SDKResponseAdapter(message_id=message_id) + adapter.set_task_id(task_id) + + async with ClaudeSDKClient(options=options) as client: + current_message = message or "" + if not current_message and 
session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" + + if not current_message.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return + + # Build query: with --resume the CLI already has full + # context, so we only send the new message. Without + # resume, compress history into a context prefix. + query_message = current_message + if not use_resume and len(session.messages) > 1: + logger.warning( + f"[SDK] Using compression fallback for session " + f"{session_id} ({len(session.messages)} messages) — " + f"no transcript available for --resume" + ) + compressed = await _compress_conversation_history(session) + history_context = _format_conversation_context(compressed) + if history_context: + query_message = ( + f"{history_context}\n\n" + f"Now, the user says:\n{current_message}" + ) + + logger.info( + f"[SDK] Sending query ({len(session.messages)} msgs in session)" + ) + logger.debug(f"[SDK] Query preview: {current_message[:80]!r}") + await client.query(query_message, session_id=session_id) + + assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] + has_appended_assistant = False + has_tool_results = False + + async for sdk_msg in client.receive_messages(): + logger.debug( + f"[SDK] Received: {type(sdk_msg).__name__} " + f"{getattr(sdk_msg, 'subtype', '')}" + ) + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + + yield response + + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, start a new assistant + # message for the post-tool text. 
+ if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + accumulated_tool_calls = [] + has_appended_assistant = False + has_tool_results = False + session.messages.append(assistant_response) + has_appended_assistant = True + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps(response.input or {}), + }, + } + ) + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else str(response.output) + ), + tool_call_id=response.toolCallId, + ) + ) + has_tool_results = True + + elif isinstance(response, StreamFinish): + stream_completed = True + + if stream_completed: + break + + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: + session.messages.append(assistant_response) + + # --- Capture transcript while CLI is still alive --- + # Must happen INSIDE async with: close() sends SIGTERM + # which kills the CLI before it can flush the JSONL. 
+ if ( + config.claude_agent_use_resume + and user_id + and captured_transcript.available + ): + # Give CLI time to flush JSONL writes before we read + await asyncio.sleep(0.5) + raw_transcript = read_transcript_file(captured_transcript.path) + if raw_transcript: + task = asyncio.create_task( + _upload_transcript_bg(user_id, session_id, raw_transcript) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + else: + logger.debug("[SDK] Stop hook fired but transcript not usable") + + except ImportError: + raise RuntimeError( + "claude-agent-sdk is not installed. " + "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) " + "to use the OpenAI-compatible fallback." + ) + + await upsert_chat_session(session) + logger.debug( + f"[SDK] Session {session_id} saved with {len(session.messages)} messages" + ) + if not stream_completed: + yield StreamFinish() + + except Exception as e: + logger.error(f"[SDK] Error: {e}", exc_info=True) + try: + await upsert_chat_session(session) + except Exception as save_err: + logger.error(f"[SDK] Failed to save session on error: {save_err}") + yield StreamError( + errorText="An error occurred. 
Please try again.", + code="sdk_error", + ) + yield StreamFinish() + finally: + if sdk_cwd: + _cleanup_sdk_tool_results(sdk_cwd) + + +async def _upload_transcript_bg( + user_id: str, session_id: str, raw_content: str +) -> None: + """Background task to strip progress entries and upload transcript.""" + try: + await upload_transcript(user_id, session_id, raw_content) + except Exception as e: + logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}") + + +async def _update_title_async( + session_id: str, message: str, user_id: str | None = None +) -> None: + """Background task to update session title.""" + try: + title = await _generate_session_title( + message, user_id=user_id, session_id=session_id + ) + if title: + await update_session_title(session_id, title) + logger.debug(f"[SDK] Generated title for {session_id}: {title}") + except Exception as e: + logger.warning(f"[SDK] Failed to update session title: {e}") diff --git a/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py new file mode 100644 index 0000000000..4e64e77e14 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py @@ -0,0 +1,322 @@ +"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools. + +This module provides the adapter layer that converts existing BaseTool implementations +into in-process MCP tools that can be used with the Claude Agent SDK. + +Long-running tools (``is_long_running=True``) are delegated to the non-SDK +background infrastructure (stream_registry, Redis persistence, SSE reconnection) +via a callback provided by the service layer. This avoids wasteful SDK polling +and makes results survive page refreshes. 
+""" + +import itertools +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from typing import Any + +from backend.copilot.model import ChatSession +from backend.copilot.tools import TOOL_REGISTRY +from backend.copilot.tools.base import BaseTool + +logger = logging.getLogger(__name__) + +# Allowed base directory for the Read tool (SDK saves oversized tool results here). +# Restricted to ~/.claude/projects/ and further validated to require "tool-results" +# in the path — prevents reading settings, credentials, or other sensitive files. +_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/") + +# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}" +MCP_SERVER_NAME = "copilot" +MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__" + +# Context variables to pass user/session info to tool execution +_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None) +_current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) +# Stash for MCP tool outputs before the SDK potentially truncates them. +# Keyed by tool_name → full output string. Consumed (popped) by the +# response adapter when it builds StreamToolOutputAvailable. +_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar( + "pending_tool_outputs", default=None # type: ignore[arg-type] +) + +# Callback type for delegating long-running tools to the non-SDK infrastructure. +# Args: (tool_name, arguments, session) → MCP-formatted response dict. +LongRunningCallback = Callable[ + [str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]] +] + +# ContextVar so the service layer can inject the callback per-request. 
def set_execution_context(
    user_id: str | None,
    session: ChatSession,
    long_running_callback: LongRunningCallback | None = None,
) -> None:
    """Install per-request context for subsequent tool calls.

    Must run before streaming starts so every tool handler can see the
    current user and session. Also resets the pending-output stash to a
    fresh dict and records the optional long-running delegation callback.

    Args:
        user_id: Current user's ID.
        session: Current chat session.
        long_running_callback: Optional callback that delegates long-running
            tools to the non-SDK background infrastructure
            (stream_registry + Redis).
    """
    _current_session.set(session)
    _current_user_id.set(user_id)
    _long_running_callback.set(long_running_callback)
    # Fresh stash per request — stale outputs must not leak across turns.
    _pending_tool_outputs.set({})


def get_execution_context() -> tuple[str | None, ChatSession | None]:
    """Return the ``(user_id, session)`` pair installed for this request."""
    return _current_user_id.get(), _current_session.get()
+ """ + pending = _pending_tool_outputs.get(None) + if pending is None: + return None + return pending.pop(tool_name, None) + + +async def _execute_tool_sync( + base_tool: BaseTool, + user_id: str | None, + session: ChatSession, + args: dict[str, Any], +) -> dict[str, Any]: + """Execute a tool synchronously and return MCP-formatted response.""" + effective_id = f"sdk-{uuid.uuid4().hex[:12]}" + result = await base_tool.execute( + user_id=user_id, + session=session, + tool_call_id=effective_id, + **args, + ) + + text = ( + result.output if isinstance(result.output, str) else json.dumps(result.output) + ) + + # Stash the full output before the SDK potentially truncates it. + pending = _pending_tool_outputs.get(None) + if pending is not None: + pending[base_tool.name] = text + + return { + "content": [{"type": "text", "text": text}], + "isError": not result.success, + } + + +def _mcp_error(message: str) -> dict[str, Any]: + return { + "content": [ + {"type": "text", "text": json.dumps({"error": message, "type": "error"})} + ], + "isError": True, + } + + +def create_tool_handler(base_tool: BaseTool): + """Create an async handler function for a BaseTool. + + This wraps the existing BaseTool._execute method to be compatible + with the Claude Agent SDK MCP tool format. + + Long-running tools (``is_long_running=True``) are delegated to the + non-SDK background infrastructure via a callback set in the execution + context. The callback persists the operation in Redis (stream_registry) + so results survive page refreshes and pod restarts. 
+ """ + + async def tool_handler(args: dict[str, Any]) -> dict[str, Any]: + """Execute the wrapped tool and return MCP-formatted response.""" + user_id, session = get_execution_context() + + if session is None: + return _mcp_error("No session context available") + + # --- Long-running: delegate to non-SDK background infrastructure --- + if base_tool.is_long_running: + callback = _long_running_callback.get(None) + if callback: + try: + return await callback(base_tool.name, args, session) + except Exception as e: + logger.error( + f"Long-running callback failed for {base_tool.name}: {e}", + exc_info=True, + ) + return _mcp_error(f"Failed to start {base_tool.name}: {e}") + # No callback — fall through to synchronous execution + logger.warning( + f"[SDK] No long-running callback for {base_tool.name}, " + f"executing synchronously (may block)" + ) + + # --- Normal (fast) tool: execute synchronously --- + try: + return await _execute_tool_sync(base_tool, user_id, session, args) + except Exception as e: + logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True) + return _mcp_error(f"Failed to execute {base_tool.name}: {e}") + + return tool_handler + + +def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]: + """Build a JSON Schema input schema for a tool.""" + return { + "type": "object", + "properties": base_tool.parameters.get("properties", {}), + "required": base_tool.parameters.get("required", []), + } + + +async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]: + """Read a file with optional offset/limit. Restricted to SDK working directory. + + After reading, the file is deleted to prevent accumulation in long-running pods. 
+ """ + file_path = args.get("file_path", "") + offset = args.get("offset", 0) + limit = args.get("limit", 2000) + + # Security: only allow reads under ~/.claude/projects/**/tool-results/ + real_path = os.path.realpath(file_path) + if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path: + return { + "content": [{"type": "text", "text": f"Access denied: {file_path}"}], + "isError": True, + } + + try: + with open(real_path) as f: + selected = list(itertools.islice(f, offset, offset + limit)) + content = "".join(selected) + # Cleanup happens in _cleanup_sdk_tool_results after session ends; + # don't delete here — the SDK may read in multiple chunks. + return {"content": [{"type": "text", "text": content}], "isError": False} + except FileNotFoundError: + return { + "content": [{"type": "text", "text": f"File not found: {file_path}"}], + "isError": True, + } + except Exception as e: + return { + "content": [{"type": "text", "text": f"Error reading file: {e}"}], + "isError": True, + } + + +_READ_TOOL_NAME = "Read" +_READ_TOOL_DESCRIPTION = ( + "Read a file from the local filesystem. " + "Use offset and limit to read specific line ranges for large files." +) +_READ_TOOL_SCHEMA = { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to read", + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (0-indexed). Default: 0", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read. Default: 2000", + }, + }, + "required": ["file_path"], +} + + +# Create the MCP server configuration +def create_copilot_mcp_server(): + """Create an in-process MCP server configuration for CoPilot tools. + + This can be passed to ClaudeAgentOptions.mcp_servers. + + Note: The actual SDK MCP server creation depends on the claude-agent-sdk + package being available. This function returns the configuration that + can be used with the SDK. 
+ """ + try: + from claude_agent_sdk import create_sdk_mcp_server, tool + + # Create decorated tool functions + sdk_tools = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + handler = create_tool_handler(base_tool) + decorated = tool( + tool_name, + base_tool.description, + _build_input_schema(base_tool), + )(handler) + sdk_tools.append(decorated) + + # Add the Read tool so the SDK can read back oversized tool results + read_tool = tool( + _READ_TOOL_NAME, + _READ_TOOL_DESCRIPTION, + _READ_TOOL_SCHEMA, + )(_read_file_handler) + sdk_tools.append(read_tool) + + server = create_sdk_mcp_server( + name=MCP_SERVER_NAME, + version="1.0.0", + tools=sdk_tools, + ) + + return server + + except ImportError: + # Let ImportError propagate so service.py handles the fallback + raise + + +# SDK built-in tools allowed within the workspace directory. +# Security hooks validate that file paths stay within sdk_cwd. +# Bash is NOT included — use the sandboxed MCP bash_exec tool instead, +# which provides kernel-level network isolation via unshare --net. +# Task allows spawning sub-agents (rate-limited by security hooks). +_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"] + +# List of tool names for allowed_tools configuration +# Include MCP tools, the MCP Read tool for oversized results, +# and SDK built-in file tools for workspace operations. +COPILOT_TOOL_NAMES = [ + *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}", + *_SDK_BUILTIN_TOOLS, +] diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript.py new file mode 100644 index 0000000000..aaa5609227 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript.py @@ -0,0 +1,356 @@ +"""JSONL transcript management for stateless multi-turn resume. + +The Claude Code CLI persists conversations as JSONL files (one JSON object per +line). 
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")

# Entry types that can be safely removed from the transcript without breaking
# the parentUuid conversation tree that ``--resume`` relies on.
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
# - file-history-snapshot: undo tracking metadata
# - queue-operation: internal queue bookkeeping
# - summary: session summaries
# - pr-link: PR link metadata
STRIPPABLE_TYPES = frozenset(
    {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)

# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"


# ---------------------------------------------------------------------------
# Progress stripping
# ---------------------------------------------------------------------------


def strip_progress_entries(content: str) -> str:
    """Remove progress/metadata entries from a JSONL transcript.

    Entries whose ``type`` is in ``STRIPPABLE_TYPES`` are dropped, and the
    ``parentUuid`` of every surviving entry is re-linked past removed
    ancestors so the conversation tree stays intact for ``--resume``.
    Unparseable lines are preserved verbatim as a safety net. Typically
    reduces transcript size by ~30%.

    Fix vs. previous version: unparseable lines are tagged out-of-band
    (``(raw, entry)`` tuples) instead of the in-band ``{"_raw": line}``
    sentinel, which corrupted any legitimate entry whose JSON happened to
    contain an ``_raw`` key.
    """
    # Parse each line; (raw_line, {}) marks a line kept verbatim.
    parsed: list[tuple[str | None, dict]] = []
    for raw_line in content.strip().split("\n"):
        try:
            parsed.append((None, json.loads(raw_line)))
        except json.JSONDecodeError:
            parsed.append((raw_line, {}))  # keep unparseable lines as-is

    parent_of: dict[str, str] = {}
    removed: set[str] = set()
    survivors: list[tuple[str | None, dict]] = []

    for raw_line, entry in parsed:
        if raw_line is not None:
            survivors.append((raw_line, entry))
            continue
        uid = entry.get("uuid", "")
        if uid:
            parent_of[uid] = entry.get("parentUuid", "")
        if entry.get("type", "") in STRIPPABLE_TYPES:
            if uid:
                removed.add(uid)
        else:
            survivors.append((raw_line, entry))

    # Re-link: hop over removed ancestors to the nearest surviving one.
    out_lines: list[str] = []
    for raw_line, entry in survivors:
        if raw_line is not None:
            out_lines.append(raw_line)
            continue
        anchor = entry.get("parentUuid", "")
        while anchor in removed:
            anchor = parent_of.get(anchor, "")
        if anchor != entry.get("parentUuid", ""):
            entry["parentUuid"] = anchor
        out_lines.append(json.dumps(entry, separators=(",", ":")))

    return "\n".join(out_lines) + "\n"
+ """ + if not transcript_path or not os.path.isfile(transcript_path): + logger.debug(f"[Transcript] File not found: {transcript_path}") + return None + + try: + with open(transcript_path) as f: + content = f.read() + + if not content.strip(): + logger.debug(f"[Transcript] Empty file: {transcript_path}") + return None + + lines = content.strip().split("\n") + if len(lines) < 3: + # Raw files with ≤2 lines are metadata-only + # (queue-operation + file-history-snapshot, no conversation). + logger.debug( + f"[Transcript] Too few lines ({len(lines)}): {transcript_path}" + ) + return None + + # Quick structural validation — parse first and last lines. + json.loads(lines[0]) + json.loads(lines[-1]) + + logger.info( + f"[Transcript] Read {len(lines)} lines, " + f"{len(content)} bytes from {transcript_path}" + ) + return content + + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}") + return None + + +def _sanitize_id(raw_id: str, max_len: int = 36) -> str: + """Sanitize an ID for safe use in file paths. + + Session/user IDs are expected to be UUIDs (hex + hyphens). Strip + everything else and truncate to *max_len* so the result cannot introduce + path separators or other special characters. + """ + cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len] + return cleaned or "unknown" + + +_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") + + +def write_transcript_to_tempfile( + transcript_content: str, + session_id: str, + cwd: str, +) -> str | None: + """Write JSONL transcript to a temp file inside *cwd* for ``--resume``. + + The file lives in the session working directory so it is cleaned up + automatically when the session ends. + + Returns the absolute path to the file, or ``None`` on failure. + """ + # Validate cwd is under the expected sandbox prefix (CodeQL sanitizer). 
+ real_cwd = os.path.realpath(cwd) + if not real_cwd.startswith(_SAFE_CWD_PREFIX): + logger.warning(f"[Transcript] cwd outside sandbox: {cwd}") + return None + + try: + os.makedirs(real_cwd, exist_ok=True) + safe_id = _sanitize_id(session_id, max_len=8) + jsonl_path = os.path.realpath( + os.path.join(real_cwd, f"transcript-{safe_id}.jsonl") + ) + if not jsonl_path.startswith(real_cwd): + logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}") + return None + + with open(jsonl_path, "w") as f: + f.write(transcript_content) + + logger.info(f"[Transcript] Wrote resume file: {jsonl_path}") + return jsonl_path + + except OSError as e: + logger.warning(f"[Transcript] Failed to write resume file: {e}") + return None + + +def validate_transcript(content: str | None) -> bool: + """Check that a transcript has actual conversation messages. + + A valid transcript for resume needs at least one user message and one + assistant message (not just queue-operation / file-history-snapshot + metadata). + """ + if not content or not content.strip(): + return False + + lines = content.strip().split("\n") + if len(lines) < 2: + return False + + has_user = False + has_assistant = False + + for line in lines: + try: + entry = json.loads(line) + msg_type = entry.get("type") + if msg_type == "user": + has_user = True + elif msg_type == "assistant": + has_assistant = True + except json.JSONDecodeError: + return False + + return has_user and has_assistant + + +# --------------------------------------------------------------------------- +# Bucket storage (GCS / local via WorkspaceStorageBackend) +# --------------------------------------------------------------------------- + + +def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for a session's transcript. + + Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` + IDs are sanitized to hex+hyphen to prevent path traversal. 
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    """Return (workspace_id, file_id, filename) for a session's transcript.

    Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
    IDs are sanitized to hex+hyphen to prevent path traversal.
    """
    return (
        TRANSCRIPT_STORAGE_PREFIX,
        _sanitize_id(user_id),
        f"{_sanitize_id(session_id)}.jsonl",
    )


def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    """Build the full storage path string that ``retrieve()`` expects.

    ``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
    ``local://workspace_id/file_id/filename``. Since we use deterministic
    arguments we can reconstruct the same path for download/delete without
    having stored the return value.
    """
    from backend.util.workspace_storage import GCSWorkspaceStorage

    wid, fid, fname = _storage_path_parts(user_id, session_id)

    if isinstance(backend, GCSWorkspaceStorage):
        blob = f"workspaces/{wid}/{fid}/{fname}"
        return f"gcs://{backend.bucket_name}/{blob}"
    else:
        # LocalWorkspaceStorage returns local://{relative_path}
        return f"local://{wid}/{fid}/{fname}"


async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
    """Strip progress entries and upload transcript to bucket storage.

    Safety: only overwrites when the new (stripped) transcript is larger than
    what is already stored. Since JSONL is append-only, the latest transcript
    is always the longest. This prevents a slow/stale background task from
    clobbering a newer upload from a concurrent turn.
    """
    from backend.util.workspace_storage import get_workspace_storage

    stripped = strip_progress_entries(content)
    if not validate_transcript(stripped):
        logger.warning(
            f"[Transcript] Skipping upload — stripped content is not a valid "
            f"transcript for session {session_id}"
        )
        return

    storage = await get_workspace_storage()
    wid, fid, fname = _storage_path_parts(user_id, session_id)
    encoded = stripped.encode("utf-8")
    new_size = len(encoded)

    # Check existing transcript size to avoid overwriting newer with older.
    path = _build_storage_path(user_id, session_id, storage)
    try:
        existing = await storage.retrieve(path)
        if len(existing) >= new_size:
            logger.info(
                f"[Transcript] Skipping upload — existing transcript "
                f"({len(existing)}B) >= new ({new_size}B) for session "
                f"{session_id}"
            )
            return
    except FileNotFoundError:
        # No existing transcript — first upload for this session.
        pass
    except Exception as e:
        # BUGFIX: the original `except (FileNotFoundError, Exception): pass`
        # was a redundant tuple (FileNotFoundError is an Exception) that
        # silently swallowed every retrieval failure. Still proceed with the
        # upload (same behavior), but log so broken storage is visible.
        logger.debug(f"[Transcript] Size check failed ({e}); uploading anyway")

    await storage.store(
        workspace_id=wid,
        file_id=fid,
        filename=fname,
        content=encoded,
    )
    logger.info(
        f"[Transcript] Uploaded {new_size} bytes "
        f"(stripped from {len(content)}) for session {session_id}"
    )
async def download_transcript(user_id: str, session_id: str) -> str | None:
    """Download transcript from bucket storage.

    Returns the JSONL content string, or ``None`` if not found.
    """
    from backend.util.workspace_storage import get_workspace_storage

    backend = await get_workspace_storage()
    location = _build_storage_path(user_id, session_id, backend)

    try:
        payload = await backend.retrieve(location)
        text = payload.decode("utf-8")
        logger.info(
            f"[Transcript] Downloaded {len(text)} bytes for session {session_id}"
        )
        return text
    except FileNotFoundError:
        # Expected on the first turn of a session — nothing stored yet.
        logger.debug(f"[Transcript] No transcript in storage for {session_id}")
        return None
    except Exception as e:
        logger.warning(f"[Transcript] Failed to download transcript: {e}")
        return None


async def delete_transcript(user_id: str, session_id: str) -> None:
    """Delete transcript from bucket storage (e.g. after resume failure)."""
    from backend.util.workspace_storage import get_workspace_storage

    backend = await get_workspace_storage()
    location = _build_storage_path(user_id, session_id, backend)

    try:
        await backend.delete(location)
        logger.info(f"[Transcript] Deleted transcript for session {session_id}")
    except Exception as e:
        # Best-effort cleanup — a stale transcript is harmless, so log only.
        logger.warning(f"[Transcript] Failed to delete transcript: {e}")
ASST_MSG = {
    "type": "assistant",
    "uuid": "a1",
    "parentUuid": "u1",
    "message": {"role": "assistant", "content": "hello"},
}
PROGRESS_ENTRY = {
    "type": "progress",
    "uuid": "p1",
    "parentUuid": "u1",
    "data": {"type": "bash_progress", "stdout": "running..."},
}

VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)


# --- read_transcript_file ---


class TestReadTranscriptFile:
    def test_returns_content_for_valid_file(self, tmp_path):
        path = tmp_path / "session.jsonl"
        path.write_text(VALID_TRANSCRIPT)
        result = read_transcript_file(str(path))
        assert result is not None
        assert "user" in result

    def test_returns_none_for_missing_file(self):
        assert read_transcript_file("/nonexistent/path.jsonl") is None

    def test_returns_none_for_empty_path(self):
        assert read_transcript_file("") is None

    def test_returns_none_for_empty_file(self, tmp_path):
        path = tmp_path / "empty.jsonl"
        path.write_text("")
        assert read_transcript_file(str(path)) is None

    def test_returns_none_for_metadata_only(self, tmp_path):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
        path = tmp_path / "meta.jsonl"
        path.write_text(content)
        assert read_transcript_file(str(path)) is None

    def test_returns_none_for_invalid_json(self, tmp_path):
        path = tmp_path / "bad.jsonl"
        path.write_text("not json\n{}\n{}\n")
        assert read_transcript_file(str(path)) is None

    def test_no_size_limit(self, tmp_path):
        """Large files are accepted — bucket storage has no size limit."""
        big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
        path = tmp_path / "big.jsonl"
        path.write_text(content)
        result = read_transcript_file(str(path))
        assert result is not None


# --- write_transcript_to_tempfile ---


class TestWriteTranscriptToTempfile:
    """Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check.

    FIX: the original tests used fixed paths (/tmp/copilot-test-write, ...),
    so two concurrent test runs collided and one run's rmtree could delete
    the other's files. Each test now gets a unique directory that still
    satisfies the required /tmp/copilot- sandbox prefix.
    """

    @staticmethod
    def _make_cwd() -> str:
        # Unique per-test dir under the required /tmp/copilot- prefix.
        import tempfile

        return tempfile.mkdtemp(prefix="copilot-test-", dir="/tmp")

    @staticmethod
    def _cleanup(path: str) -> None:
        import shutil

        shutil.rmtree(path, ignore_errors=True)

    def test_writes_file_and_returns_path(self):
        cwd = self._make_cwd()
        try:
            result = write_transcript_to_tempfile(
                VALID_TRANSCRIPT, "sess-1234-abcd", cwd
            )
            assert result is not None
            assert os.path.isfile(result)
            assert result.endswith(".jsonl")
            with open(result) as f:
                assert f.read() == VALID_TRANSCRIPT
        finally:
            self._cleanup(cwd)

    def test_creates_parent_directory(self):
        base = self._make_cwd()
        # Point at a subdirectory that does not exist yet so the function's
        # makedirs path is actually exercised.
        cwd = os.path.join(base, "nested")
        try:
            result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
            assert result is not None
            assert os.path.isdir(cwd)
        finally:
            self._cleanup(base)

    def test_uses_session_id_prefix(self):
        cwd = self._make_cwd()
        try:
            result = write_transcript_to_tempfile(
                VALID_TRANSCRIPT, "abcdef12-rest", cwd
            )
            assert result is not None
            assert "abcdef12" in os.path.basename(result)
        finally:
            self._cleanup(cwd)

    def test_rejects_cwd_outside_sandbox(self, tmp_path):
        cwd = str(tmp_path / "not-copilot")
        result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
        assert result is None


# --- validate_transcript ---


class TestValidateTranscript:
    def test_valid_transcript(self):
        assert validate_transcript(VALID_TRANSCRIPT) is True

    def test_none_content(self):
        assert validate_transcript(None) is False

    def test_empty_content(self):
        assert validate_transcript("") is False

    def test_metadata_only(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
        assert validate_transcript(content) is False

    def test_user_only_no_assistant(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
        assert validate_transcript(content) is False

    def test_assistant_only_no_user(self):
        content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
        assert validate_transcript(content) is False

    def test_invalid_json_returns_false(self):
        assert validate_transcript("not json\n{}\n{}\n") is False


# --- strip_progress_entries ---


class TestStripProgressEntries:
    def test_strips_all_strippable_types(self):
        """All STRIPPABLE_TYPES are removed from the output."""
        entries = [
            USER_MSG,
            {"type": "progress", "uuid": "p1", "parentUuid": "u1"},
            {"type": "file-history-snapshot", "files": []},
            {"type": "queue-operation", "subtype": "create"},
            {"type": "summary", "text": "..."},
            {"type": "pr-link", "url": "..."},
            ASST_MSG,
        ]
        result = strip_progress_entries(_make_jsonl(*entries))
        result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
        assert result_types == {"user", "assistant"}
        for stype in STRIPPABLE_TYPES:
            assert stype not in result_types

    def test_reparents_children_of_stripped_entries(self):
        """An assistant message whose parent is a progress entry gets reparented."""
        progress = {
            "type": "progress",
            "uuid": "p1",
            "parentUuid": "u1",
            "data": {"type": "bash_progress"},
        }
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "p1",  # Points to progress
            "message": {"role": "assistant", "content": "done"},
        }
        content = _make_jsonl(USER_MSG, progress, asst)
        result = strip_progress_entries(content)
        lines = [json.loads(line) for line in result.strip().split("\n")]

        asst_entry = next(e for e in lines if e["type"] == "assistant")
        # Should be reparented to u1 (the user message)
        assert asst_entry["parentUuid"] == "u1"

    def test_reparents_through_chain(self):
        """Reparenting walks through multiple stripped entries."""
        p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
        p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
        p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "p3",  # 3 levels deep
            "message": {"role": "assistant", "content": "done"},
        }
        content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
        result = strip_progress_entries(content)
        lines = [json.loads(line) for line in result.strip().split("\n")]

        asst_entry = next(e for e in lines if e["type"] == "assistant")
        assert asst_entry["parentUuid"] == "u1"

    def test_preserves_non_strippable_entries(self):
        """User, assistant, and system entries are preserved."""
        system = {"type": "system", "uuid": "s1", "message": "prompt"}
        content = _make_jsonl(system, USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
        assert result_types == ["system", "user", "assistant"]

    def test_empty_input(self):
        result = strip_progress_entries("")
        # Should return just a newline (empty content stripped)
        assert result.strip() == ""

    def test_no_strippable_entries(self):
        """When there's nothing to strip, output matches input structure."""
        content = _make_jsonl(USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_lines = result.strip().split("\n")
        assert len(result_lines) == 2

    def test_handles_entries_without_uuid(self):
        """Entries without uuid field are handled gracefully."""
        no_uuid = {"type": "queue-operation", "subtype": "create"}
        content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
        result = strip_progress_entries(content)
        result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
        # queue-operation is strippable
        assert "queue-operation" not in result_types
        assert "user" in result_types
        assert "assistant" in result_types
_build_system_prompt(user_id: str | None) -> tuple[str, Any]: +async def _build_system_prompt( + user_id: str | None, has_conversation_history: bool = False +) -> tuple[str, Any]: """Build the full system prompt including business understanding if available. Args: - user_id: The user ID for fetching business understanding - If "default" and this is the user's first session, will use "onboarding" instead. + user_id: The user ID for fetching business understanding. + has_conversation_history: Whether there's existing conversation history. + If True, we don't tell the model to greet/introduce (since they're + already in a conversation). Returns: Tuple of (compiled prompt string, business understanding object) @@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: if understanding: context = format_understanding_for_prompt(understanding) + elif has_conversation_history: + context = "No prior understanding saved yet. Continue the existing conversation naturally." else: context = "This is the first time you are meeting the user. 
Greet them and introduce them to the platform" @@ -374,7 +380,6 @@ async def stream_chat_completion( Raises: NotFoundError: If session_id is invalid - ValueError: If max_context_messages is exceeded """ completion_start = time.monotonic() @@ -459,8 +464,9 @@ async def stream_chat_completion( # Generate title for new sessions on first user message (non-blocking) # Check: is_user_message, no title yet, and this is the first user message - if is_user_message and message and not session.title: - user_messages = [m for m in session.messages if m.role == "user"] + user_messages = [m for m in session.messages if m.role == "user"] + first_user_msg = message or (user_messages[0].content if user_messages else None) + if is_user_message and first_user_msg and not session.title: if len(user_messages) == 1: # First user message - generate title in background import asyncio @@ -468,7 +474,7 @@ async def stream_chat_completion( # Capture only the values we need (not the session object) to avoid # stale data issues when the main flow modifies the session captured_session_id = session_id - captured_message = message + captured_message = first_user_msg captured_user_id = user_id async def _update_title(): @@ -1237,7 +1243,7 @@ async def _stream_chat_chunks( total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 logger.info( - f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; " + f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; " f"session={session.session_id}, user={session.user_id}", extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, ) diff --git a/autogpt_platform/backend/backend/copilot/service_test.py b/autogpt_platform/backend/backend/copilot/service_test.py index 70f27af14f..b2fc82b790 100644 --- a/autogpt_platform/backend/backend/copilot/service_test.py +++ b/autogpt_platform/backend/backend/copilot/service_test.py @@ -1,3 +1,4 @@ +import asyncio import logging from os import getenv @@ -11,6 +12,8 @@ from 
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
    """Test that the SDK --resume path captures and uses transcripts across turns.

    Turn 1: Send a message containing a unique keyword.
    Turn 2: Ask the model to recall that keyword — proving the transcript was
    persisted and restored via --resume.
    """
    # Integration test: requires a live OpenRouter key and the resume feature
    # flag; skipped otherwise so CI without secrets stays green.
    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
    if not api_key:
        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

    from .config import ChatConfig

    cfg = ChatConfig()
    if not cfg.claude_agent_use_resume:
        return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")

    session = await create_chat_session(test_user_id)
    session = await upsert_chat_session(session)

    # --- Turn 1: send a message with a unique keyword ---
    # The keyword is an arbitrary token unlikely to appear by chance in
    # model output, so recalling it in turn 2 proves transcript restore.
    keyword = "ZEPHYR42"
    turn1_msg = (
        f"Please remember this special keyword: {keyword}. "
        "Just confirm you've noted it, keep your response brief."
    )
    turn1_text = ""
    turn1_errors: list[str] = []
    turn1_ended = False

    # Drain the stream, accumulating text deltas and recording errors/finish.
    async for chunk in sdk_service.stream_chat_completion_sdk(
        session.session_id,
        turn1_msg,
        user_id=test_user_id,
    ):
        if isinstance(chunk, StreamTextDelta):
            turn1_text += chunk.delta
        elif isinstance(chunk, StreamError):
            turn1_errors.append(chunk.errorText)
        elif isinstance(chunk, StreamFinish):
            turn1_ended = True

    assert turn1_ended, "Turn 1 did not finish"
    assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
    assert turn1_text, "Turn 1 produced no text"

    # Wait for background upload task to complete (retry up to 5s)
    # The Stop-hook upload runs as a fire-and-forget task, so poll storage.
    transcript = None
    for _ in range(10):
        await asyncio.sleep(0.5)
        transcript = await download_transcript(test_user_id, session.session_id)
        if transcript:
            break
    assert transcript, (
        "Transcript was not uploaded to bucket after turn 1 — "
        "Stop hook may not have fired or transcript was too small"
    )
    logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")

    # Reload session for turn 2
    # NOTE(review): get_chat_session is called here with (session_id, user_id)
    # — presumably user_id scopes ownership; confirm signature matches other
    # call sites that pass only session_id.
    session = await get_chat_session(session.session_id, test_user_id)
    assert session, "Session not found after turn 1"

    # --- Turn 2: ask model to recall the keyword ---
    turn2_msg = "What was the special keyword I asked you to remember?"
    turn2_text = ""
    turn2_errors: list[str] = []
    turn2_ended = False

    # Turn 2 passes the reloaded session object so the resume path (download
    # transcript + --resume) is exercised rather than a fresh conversation.
    async for chunk in sdk_service.stream_chat_completion_sdk(
        session.session_id,
        turn2_msg,
        user_id=test_user_id,
        session=session,
    ):
        if isinstance(chunk, StreamTextDelta):
            turn2_text += chunk.delta
        elif isinstance(chunk, StreamError):
            turn2_errors.append(chunk.errorText)
        elif isinstance(chunk, StreamFinish):
            turn2_ended = True

    assert turn2_ended, "Turn 2 did not finish"
    assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
    assert turn2_text, "Turn 2 produced no text"
    assert keyword in turn2_text, (
        f"Model did not recall keyword '{keyword}' in turn 2. "
        f"Response: {turn2_text[:200]}"
    )
    logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
+ ) + # Get the last message ID from Redis Stream stream_key = _get_task_stream_key(task_id) last_id = "0-0" diff --git a/autogpt_platform/backend/backend/copilot/tools/__init__.py b/autogpt_platform/backend/backend/copilot/tools/__init__.py index e24c927112..0593fe69c0 100644 --- a/autogpt_platform/backend/backend/copilot/tools/__init__.py +++ b/autogpt_platform/backend/backend/copilot/tools/__init__.py @@ -9,9 +9,12 @@ from backend.copilot.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool +from .bash_exec import BashExecTool +from .check_operation_status import CheckOperationStatusTool from .create_agent import CreateAgentTool from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool from .find_agent import FindAgentTool from .find_block import FindBlockTool from .find_library_agent import FindLibraryAgentTool @@ -19,6 +22,7 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .web_fetch import WebFetchTool from .workspace_files import ( DeleteWorkspaceFileTool, ListWorkspaceFilesTool, @@ -43,8 +47,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "run_agent": RunAgentTool(), "run_block": RunBlockTool(), "view_agent_output": AgentOutputTool(), + "check_operation_status": CheckOperationStatusTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Web fetch for safe URL retrieval + "web_fetch": WebFetchTool(), + # Sandboxed code execution (bubblewrap) + "bash_exec": BashExecTool(), + # Persistent workspace tools (cloud storage, survives across sessions) + # Feature request tools + "search_feature_requests": SearchFeatureRequestsTool(), + "create_feature_request": CreateFeatureRequestTool(), # Workspace tools for CoPilot file 
class BashExecTool(BaseTool):
    """Execute Bash commands in a bubblewrap sandbox.

    Full Bash scripting is allowed; safety comes from OS-level isolation
    (read-only system dirs, writable workspace only, clean env, no network).
    Disabled when bubblewrap is unavailable (e.g. macOS development).
    """

    @property
    def name(self) -> str:
        return "bash_exec"

    @property
    def description(self) -> str:
        # Advertise the disabled state so the model doesn't call the tool
        # on platforms without bubblewrap.
        if not has_full_sandbox():
            return (
                "Bash execution is DISABLED — bubblewrap sandbox is not "
                "available on this platform. Do not call this tool."
            )
        return (
            "Execute a Bash command or script in a bubblewrap sandbox. "
            "Full Bash scripting is supported (loops, conditionals, pipes, "
            "functions, etc.). "
            "The sandbox shares the same working directory as the SDK Read/Write "
            "tools — files created by either are accessible to both. "
            "SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
            "visible read-only, the per-session workspace is the only writable "
            "path, environment variables are wiped (no secrets), all network "
            "access is blocked at the kernel level, and resource limits are "
            "enforced (max 64 processes, 512MB memory, 50MB file size). "
            "Application code, configs, and other directories are NOT accessible. "
            "To fetch web content, use the web_fetch tool instead. "
            "Execution is killed after the timeout (default 30s, max 120s). "
            "Returns stdout and stderr. "
            "Useful for file manipulation, data processing with Unix tools "
            "(grep, awk, sed, jq, etc.), and running shell scripts."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "Bash command or script to execute.",
                },
                "timeout": {
                    "type": "integer",
                    "description": (
                        "Max execution time in seconds (default 30, max 120)."
                    ),
                    "default": 30,
                },
            },
            "required": ["command"],
        }

    @property
    def requires_auth(self) -> bool:
        return False

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Run *command* in the sandbox and return stdout/stderr/exit code."""
        session_id = session.session_id if session else None

        if not has_full_sandbox():
            return ErrorResponse(
                message="bash_exec requires bubblewrap sandbox (Linux only).",
                error="sandbox_unavailable",
                session_id=session_id,
            )

        command: str = (kwargs.get("command") or "").strip()

        if not command:
            return ErrorResponse(
                message="No command provided.",
                error="empty_command",
                session_id=session_id,
            )

        # BUGFIX: the description promises "default 30s, max 120s", but the
        # original passed the model-supplied value through unchecked (and
        # uncoerced). Coerce to int and clamp to the documented [1, 120]
        # range here (defensively — run_sandboxed may also enforce its own
        # limit; TODO confirm).
        try:
            timeout = int(kwargs.get("timeout", 30))
        except (TypeError, ValueError):
            timeout = 30
        timeout = max(1, min(timeout, 120))

        workspace = get_workspace_dir(session_id or "default")

        stdout, stderr, exit_code, timed_out = await run_sandboxed(
            command=["bash", "-c", command],
            cwd=workspace,
            timeout=timeout,
        )

        return BashExecResponse(
            message=(
                "Execution timed out"
                if timed_out
                else f"Command executed (exit {exit_code})"
            ),
            stdout=stdout,
            stderr=stderr,
            exit_code=exit_code,
            timed_out=timed_out,
            session_id=session_id,
        )
"running", "completed", "failed" + tool_name: str | None = None + message: str = "" + + +class CheckOperationStatusTool(BaseTool): + """Check the status of a long-running operation (create_agent, edit_agent, etc.). + + The CoPilot uses this tool to report back to the user whether an + operation that was started earlier has completed, failed, or is still + running. + """ + + @property + def name(self) -> str: + return "check_operation_status" + + @property + def description(self) -> str: + return ( + "Check the current status of a long-running operation such as " + "create_agent or edit_agent. Accepts either an operation_id or " + "task_id from a previous operation_started response. " + "Returns the current status: running, completed, or failed." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "operation_id": { + "type": "string", + "description": ( + "The operation_id from an operation_started response." + ), + }, + "task_id": { + "type": "string", + "description": ( + "The task_id from an operation_started response. " + "Used as fallback if operation_id is not provided." + ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + from backend.copilot import stream_registry + + operation_id = (kwargs.get("operation_id") or "").strip() + task_id = (kwargs.get("task_id") or "").strip() + + if not operation_id and not task_id: + return ErrorResponse( + message="Please provide an operation_id or task_id.", + error="missing_parameter", + ) + + task = None + if operation_id: + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None and task_id: + task = await stream_registry.get_task(task_id) + + if task is None: + # Task not in Redis — it may have already expired (TTL). 
+ # Check conversation history for the result instead. + return ErrorResponse( + message=( + "Operation not found — it may have already completed and " + "expired from the status tracker. Check the conversation " + "history for the result." + ), + error="not_found", + ) + + status_messages = { + "running": ( + f"The {task.tool_name or 'operation'} is still running. " + "Please wait for it to complete." + ), + "completed": ( + f"The {task.tool_name or 'operation'} has completed successfully." + ), + "failed": f"The {task.tool_name or 'operation'} has failed.", + } + + return OperationStatusResponse( + task_id=task.task_id, + operation_id=task.operation_id, + status=task.status, + tool_name=task.tool_name, + message=status_messages.get(task.status, f"Status: {task.status}"), + ) diff --git a/autogpt_platform/backend/backend/copilot/tools/feature_requests.py b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py new file mode 100644 index 0000000000..ebfc37f475 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py @@ -0,0 +1,449 @@ +"""Feature request tools - search and create feature requests via Linear.""" + +import logging +from typing import Any + +from pydantic import SecretStr + +from backend.blocks.linear._api import LinearClient +from backend.copilot.model import ChatSession +from backend.data.model import APIKeyCredentials +from backend.data.user import get_user_email_by_id +from backend.util.settings import Settings + +from .base import BaseTool +from .models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestInfo, + FeatureRequestSearchResponse, + NoResultsResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + +MAX_SEARCH_RESULTS = 10 + +# GraphQL queries/mutations +SEARCH_ISSUES_QUERY = """ +query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) { + searchIssues(term: $term, filter: $filter, first: $first) { + nodes { + id + identifier + title 
+ description + } + } +} +""" + +CUSTOMER_UPSERT_MUTATION = """ +mutation CustomerUpsert($input: CustomerUpsertInput!) { + customerUpsert(input: $input) { + success + customer { + id + name + externalIds + } + } +} +""" + +ISSUE_CREATE_MUTATION = """ +mutation IssueCreate($input: IssueCreateInput!) { + issueCreate(input: $input) { + success + issue { + id + identifier + title + url + } + } +} +""" + +CUSTOMER_NEED_CREATE_MUTATION = """ +mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) { + customerNeedCreate(input: $input) { + success + need { + id + body + customer { + id + name + } + issue { + id + identifier + title + url + } + } + } +} +""" + + +_settings: Settings | None = None + + +def _get_settings() -> Settings: + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def _get_linear_config() -> tuple[LinearClient, str, str]: + """Return a configured Linear client, project ID, and team ID. + + Raises RuntimeError if any required setting is missing. 
+ """ + secrets = _get_settings().secrets + if not secrets.linear_api_key: + raise RuntimeError("LINEAR_API_KEY is not configured") + if not secrets.linear_feature_request_project_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured") + if not secrets.linear_feature_request_team_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_TEAM_ID is not configured") + + credentials = APIKeyCredentials( + id="system-linear", + provider="linear", + api_key=SecretStr(secrets.linear_api_key), + title="System Linear API Key", + ) + client = LinearClient(credentials=credentials) + return ( + client, + secrets.linear_feature_request_project_id, + secrets.linear_feature_request_team_id, + ) + + +class SearchFeatureRequestsTool(BaseTool): + """Tool for searching existing feature requests in Linear.""" + + @property + def name(self) -> str: + return "search_feature_requests" + + @property + def description(self) -> str: + return ( + "Search existing feature requests to check if a similar request " + "already exists before creating a new one. Returns matching feature " + "requests with their ID, title, and description." 
+ ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search term to find matching feature requests.", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + query = kwargs.get("query", "").strip() + session_id = session.session_id if session else None + + if not query: + return ErrorResponse( + message="Please provide a search query.", + error="Missing query parameter", + session_id=session_id, + ) + + try: + client, project_id, _team_id = _get_linear_config() + data = await client.query( + SEARCH_ISSUES_QUERY, + { + "term": query, + "filter": { + "project": {"id": {"eq": project_id}}, + }, + "first": MAX_SEARCH_RESULTS, + }, + ) + + nodes = data.get("searchIssues", {}).get("nodes", []) + + if not nodes: + return NoResultsResponse( + message=f"No feature requests found matching '{query}'.", + suggestions=[ + "Try different keywords", + "Use broader search terms", + "You can create a new feature request if none exists", + ], + session_id=session_id, + ) + + results = [ + FeatureRequestInfo( + id=node["id"], + identifier=node["identifier"], + title=node["title"], + description=node.get("description"), + ) + for node in nodes + ] + + return FeatureRequestSearchResponse( + message=f"Found {len(results)} feature request(s) matching '{query}'.", + results=results, + count=len(results), + query=query, + session_id=session_id, + ) + except Exception as e: + logger.exception("Failed to search feature requests") + return ErrorResponse( + message="Failed to search feature requests.", + error=str(e), + session_id=session_id, + ) + + +class CreateFeatureRequestTool(BaseTool): + """Tool for creating feature requests (or adding needs to existing ones).""" + + @property + def name(self) -> str: + return 
"create_feature_request" + + @property + def description(self) -> str: + return ( + "Create a new feature request or add a customer need to an existing one. " + "Always search first with search_feature_requests to avoid duplicates. " + "If a matching request exists, pass its ID as existing_issue_id to add " + "the user's need to it instead of creating a duplicate." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title for the feature request.", + }, + "description": { + "type": "string", + "description": "Detailed description of what the user wants and why.", + }, + "existing_issue_id": { + "type": "string", + "description": ( + "If adding a need to an existing feature request, " + "provide its Linear issue ID (from search results). " + "Omit to create a new feature request." + ), + }, + }, + "required": ["title", "description"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _find_or_create_customer( + self, client: LinearClient, user_id: str, name: str + ) -> dict: + """Find existing customer by user_id or create a new one via upsert. + + Args: + client: Linear API client. + user_id: Stable external ID used to deduplicate customers. + name: Human-readable display name (e.g. the user's email). 
+ """ + data = await client.mutate( + CUSTOMER_UPSERT_MUTATION, + { + "input": { + "name": name, + "externalId": user_id, + }, + }, + ) + result = data.get("customerUpsert", {}) + if not result.get("success"): + raise RuntimeError(f"Failed to upsert customer: {data}") + return result["customer"] + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + title = kwargs.get("title", "").strip() + description = kwargs.get("description", "").strip() + existing_issue_id = kwargs.get("existing_issue_id") + session_id = session.session_id if session else None + + if not title or not description: + return ErrorResponse( + message="Both title and description are required.", + error="Missing required parameters", + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="Authentication required to create feature requests.", + error="Missing user_id", + session_id=session_id, + ) + + try: + client, project_id, team_id = _get_linear_config() + except Exception as e: + logger.exception("Failed to initialize Linear client") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Resolve a human-readable name (email) for the Linear customer record. + # Fall back to user_id if the lookup fails or returns None. 
+ try: + customer_display_name = await get_user_email_by_id(user_id) or user_id + except Exception: + customer_display_name = user_id + + # Step 1: Find or create customer for this user + try: + customer = await self._find_or_create_customer( + client, user_id, customer_display_name + ) + customer_id = customer["id"] + customer_name = customer["name"] + except Exception as e: + logger.exception("Failed to upsert customer in Linear") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Step 2: Create or reuse issue + issue_id: str | None = None + issue_identifier: str | None = None + if existing_issue_id: + # Add need to existing issue - we still need the issue details for response + is_new_issue = False + issue_id = existing_issue_id + else: + # Create new issue in the feature requests project + try: + data = await client.mutate( + ISSUE_CREATE_MUTATION, + { + "input": { + "title": title, + "description": description, + "teamId": team_id, + "projectId": project_id, + }, + }, + ) + result = data.get("issueCreate", {}) + if not result.get("success"): + return ErrorResponse( + message="Failed to create feature request issue.", + error=str(data), + session_id=session_id, + ) + issue = result["issue"] + issue_id = issue["id"] + issue_identifier = issue.get("identifier") + except Exception as e: + logger.exception("Failed to create feature request issue") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + is_new_issue = True + + # Step 3: Create customer need on the issue + try: + data = await client.mutate( + CUSTOMER_NEED_CREATE_MUTATION, + { + "input": { + "customerId": customer_id, + "issueId": issue_id, + "body": description, + "priority": 0, + }, + }, + ) + need_result = data.get("customerNeedCreate", {}) + if not need_result.get("success"): + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if 
is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(data), + details=orphaned, + session_id=session_id, + ) + need = need_result["need"] + issue_info = need["issue"] + except Exception as e: + logger.exception("Failed to create customer need") + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(e), + details=orphaned, + session_id=session_id, + ) + + return FeatureRequestCreatedResponse( + message=( + f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'}: " + f"{issue_info['title']}." + ), + issue_id=issue_info["id"], + issue_identifier=issue_info["identifier"], + issue_title=issue_info["title"], + issue_url=issue_info.get("url", ""), + is_new_issue=is_new_issue, + customer_name=customer_name, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py b/autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py new file mode 100644 index 0000000000..9e8104d90d --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py @@ -0,0 +1,611 @@ +"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from ._test_data import make_session +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool +from .models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestSearchResponse, + NoResultsResponse, +) + +_TEST_USER_ID = "test-user-feature-requests" +_TEST_USER_EMAIL = "testuser@example.com" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + 
+ +_FAKE_PROJECT_ID = "test-project-id" +_FAKE_TEAM_ID = "test-team-id" + + +def _mock_linear_config(*, query_return=None, mutate_return=None): + """Return a patched _get_linear_config that yields a mock LinearClient.""" + client = AsyncMock() + if query_return is not None: + client.query.return_value = query_return + if mutate_return is not None: + client.mutate.return_value = mutate_return + return ( + patch( + "backend.copilot.tools.feature_requests._get_linear_config", + return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID), + ), + client, + ) + + +def _search_response(nodes: list[dict]) -> dict: + return {"searchIssues": {"nodes": nodes}} + + +def _customer_upsert_response( + customer_id: str = "cust-1", name: str = _TEST_USER_EMAIL, success: bool = True +) -> dict: + return { + "customerUpsert": { + "success": success, + "customer": {"id": customer_id, "name": name, "externalIds": [name]}, + } + } + + +def _issue_create_response( + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "issueCreate": { + "success": success, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + } + } + + +def _need_create_response( + need_id: str = "need-1", + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "customerNeedCreate": { + "success": success, + "need": { + "id": need_id, + "body": "description", + "customer": {"id": "cust-1", "name": _TEST_USER_EMAIL}, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + }, + } + } + + +# =========================================================================== +# SearchFeatureRequestsTool +# =========================================================================== + + +class TestSearchFeatureRequestsTool: 
+ """Tests for SearchFeatureRequestsTool._execute.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_successful_search(self): + session = make_session(user_id=_TEST_USER_ID) + nodes = [ + { + "id": "id-1", + "identifier": "FR-1", + "title": "Dark mode", + "description": "Add dark mode support", + }, + { + "id": "id-2", + "identifier": "FR-2", + "title": "Dark theme", + "description": None, + }, + ] + patcher, _ = _mock_linear_config(query_return=_search_response(nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="dark mode" + ) + + assert isinstance(resp, FeatureRequestSearchResponse) + assert resp.count == 2 + assert resp.results[0].id == "id-1" + assert resp.results[1].identifier == "FR-2" + assert resp.query == "dark mode" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_results(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, _ = _mock_linear_config(query_return=_search_response([])) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="nonexistent" + ) + + assert isinstance(resp, NoResultsResponse) + assert "nonexistent" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_empty_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session, query=" ") + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "query" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + 
async def test_api_failure(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.query.side_effect = RuntimeError("Linear API down") + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Linear API down" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_malformed_node_returns_error(self): + """A node missing required keys should be caught by the try/except.""" + session = make_session(user_id=_TEST_USER_ID) + # Node missing 'identifier' key + bad_nodes = [{"id": "id-1", "title": "Missing identifier"}] + patcher, _ = _mock_linear_config(query_return=_search_response(bad_nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.copilot.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + +# =========================================================================== +# CreateFeatureRequestTool +# =========================================================================== + + +class TestCreateFeatureRequestTool: + """Tests for CreateFeatureRequestTool._execute.""" + + @pytest.fixture(autouse=True) + def _patch_email_lookup(self): + with patch( + "backend.copilot.tools.feature_requests.get_user_email_by_id", + new_callable=AsyncMock, + 
return_value=_TEST_USER_EMAIL, + ): + yield + + # ---- Happy paths ------------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_create_new_issue(self): + """Full happy path: upsert customer -> create issue -> attach need.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + _need_create_response(), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="New Feature", + description="Please add this", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is True + assert resp.issue_identifier == "FR-1" + assert resp.customer_name == _TEST_USER_EMAIL + assert client.mutate.call_count == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_add_need_to_existing_issue(self): + """When existing_issue_id is provided, skip issue creation.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(issue_id="existing-1", identifier="FR-99"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Existing Feature", + description="Me too", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is False + assert resp.issue_id == "existing-1" + # Only 2 mutations: customer upsert + need create (no issue create) + assert client.mutate.call_count == 2 + + # ---- Validation errors ------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_title(self): + session = make_session(user_id=_TEST_USER_ID) + tool = 
CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="", + description="some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_description(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Some title", + description="", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_user_id(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=None, + session=session, + title="Some title", + description="Some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "user_id" in resp.error.lower() + + # ---- Linear client init failure ---------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.copilot.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + # ---- Customer upsert failures ------------------------------------------ + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = 
RuntimeError("Customer API error") + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Customer API error" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.return_value = _customer_upsert_response(success=False) + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_malformed_response(self): + """Customer dict missing 'id' key should be caught.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + # success=True but customer has no 'id' + client.mutate.return_value = { + "customerUpsert": { + "success": True, + "customer": {"name": _TEST_USER_ID}, + } + } + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Issue creation failures ------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Issue create failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, 
ErrorResponse) + assert resp.error is not None + assert "Issue create failed" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert "Failed to create feature request issue" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_malformed_response(self): + """issueCreate success=True but missing 'issue' key.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + {"issueCreate": {"success": True}}, # no 'issue' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Customer need attachment failures --------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_new_issue(self): + """Need creation fails after new issue was created -> orphaned issue info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-1", identifier="FR-10"), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert 
resp.error is not None + assert "Need attach failed" in resp.error + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-1" + assert resp.details["issue_identifier"] == "FR-10" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_existing_issue(self): + """Need creation fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_includes_orphaned_info(self): + """customerNeedCreate returns success=False -> includes orphaned issue.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-2", identifier="FR-20"), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-2" + assert resp.details["issue_identifier"] == "FR-20" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_existing_issue_no_details(self): + """customerNeedCreate fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + 
_customer_upsert_response(), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_malformed_response(self): + """need_result missing 'need' key after success=True.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + {"customerNeedCreate": {"success": True}}, # no 'need' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "issue-1" diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block.py b/autogpt_platform/backend/backend/copilot/tools/find_block.py index 0d179c7fc1..a3f784f3a8 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block.py @@ -6,14 +6,15 @@ from prisma.enums import ContentType from backend.blocks import get_block from backend.blocks._base import BlockType from backend.copilot.model import ChatSession -from backend.copilot.tools.base import BaseTool, ToolResponseBase -from backend.copilot.tools.models import ( +from backend.data.db_accessors import search + +from .base import BaseTool, ToolResponseBase +from .models import ( BlockInfoSummary, BlockListResponse, ErrorResponse, NoResultsResponse, ) -from backend.data.db_accessors import search logger = logging.getLogger(__name__) @@ -146,6 +147,7 @@ class FindBlockTool(BaseTool): id=block_id, 
name=block.name, description=block.description or "", + categories=[c.value for c in block.categories], ) ) diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py index 414bbdc6f0..ebd3c761ab 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py @@ -5,14 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.blocks._base import BlockType -from backend.copilot.tools.find_block import ( + +from ._test_data import make_session +from .find_block import ( COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES, FindBlockTool, ) -from backend.copilot.tools.models import BlockListResponse - -from ._test_data import make_session +from .models import BlockListResponse _TEST_USER_ID = "test-user-find-block" diff --git a/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py b/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py index c923a133c5..87ec7225a5 100644 --- a/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py +++ b/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py @@ -5,12 +5,9 @@ from pathlib import Path from typing import Any from backend.copilot.model import ChatSession -from backend.copilot.tools.base import BaseTool -from backend.copilot.tools.models import ( - DocPageResponse, - ErrorResponse, - ToolResponseBase, -) + +from .base import BaseTool +from .models import DocPageResponse, ErrorResponse, ToolResponseBase logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/copilot/tools/models.py b/autogpt_platform/backend/backend/copilot/tools/models.py index bd19d590a6..b32f6ca2ce 100644 --- a/autogpt_platform/backend/backend/copilot/tools/models.py +++ b/autogpt_platform/backend/backend/copilot/tools/models.py @@ -41,6 +41,15 @@ class ResponseType(str, Enum): 
OPERATION_IN_PROGRESS = "operation_in_progress" # Input validation INPUT_VALIDATION_ERROR = "input_validation_error" + # Web fetch + WEB_FETCH = "web_fetch" + # Code execution + BASH_EXEC = "bash_exec" + # Operation status check + OPERATION_STATUS = "operation_status" + # Feature request types + FEATURE_REQUEST_SEARCH = "feature_request_search" + FEATURE_REQUEST_CREATED = "feature_request_created" # Base response model @@ -335,6 +344,19 @@ class BlockInfoSummary(BaseModel): id: str name: str description: str + categories: list[str] + input_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block inputs", + ) + output_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block outputs", + ) + required_inputs: list[BlockInputFieldInfo] = Field( + default_factory=list, + description="List of input fields for this block", + ) class BlockListResponse(ToolResponseBase): @@ -344,6 +366,10 @@ class BlockListResponse(ToolResponseBase): blocks: list[BlockInfoSummary] count: int query: str + usage_hint: str = Field( + default="To execute a block, call run_block with block_id set to the block's " + "'id' field and input_data containing the fields listed in required_inputs." 
+ ) class BlockDetails(BaseModel): @@ -430,3 +456,55 @@ class AsyncProcessingResponse(ToolResponseBase): status: str = "accepted" # Must be "accepted" for detection operation_id: str | None = None task_id: str | None = None + + +class WebFetchResponse(ToolResponseBase): + """Response for web_fetch tool.""" + + type: ResponseType = ResponseType.WEB_FETCH + url: str + status_code: int + content_type: str + content: str + truncated: bool = False + + +class BashExecResponse(ToolResponseBase): + """Response for bash_exec tool.""" + + type: ResponseType = ResponseType.BASH_EXEC + stdout: str + stderr: str + exit_code: int + timed_out: bool = False + + +# Feature request models +class FeatureRequestInfo(BaseModel): + """Information about a feature request issue.""" + + id: str + identifier: str + title: str + description: str | None = None + + +class FeatureRequestSearchResponse(ToolResponseBase): + """Response for search_feature_requests tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH + results: list[FeatureRequestInfo] + count: int + query: str + + +class FeatureRequestCreatedResponse(ToolResponseBase): + """Response for create_feature_request tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED + issue_id: str + issue_identifier: str + issue_title: str + issue_url: str + is_new_issue: bool # False if added to existing + customer_name: str diff --git a/autogpt_platform/backend/backend/copilot/tools/run_block.py b/autogpt_platform/backend/backend/copilot/tools/run_block.py index 8e899e6960..32f249626b 100644 --- a/autogpt_platform/backend/backend/copilot/tools/run_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_block.py @@ -10,10 +10,6 @@ from pydantic_core import PydanticUndefined from backend.blocks import get_block from backend.blocks._base import AnyBlockSchema from backend.copilot.model import ChatSession -from backend.copilot.tools.find_block import ( - COPILOT_EXCLUDED_BLOCK_IDS, - 
COPILOT_EXCLUDED_BLOCK_TYPES, -) from backend.data.db_accessors import workspace_db from backend.data.execution import ExecutionContext from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput @@ -21,6 +17,7 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError from .base import BaseTool +from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES from .helpers import get_inputs_from_schema from .models import ( BlockDetails, diff --git a/autogpt_platform/backend/backend/copilot/tools/run_block_test.py b/autogpt_platform/backend/backend/copilot/tools/run_block_test.py index b13a339127..7ab4d706a2 100644 --- a/autogpt_platform/backend/backend/copilot/tools/run_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_block_test.py @@ -5,15 +5,15 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.blocks._base import BlockType -from backend.copilot.tools.models import ( + +from ._test_data import make_session +from .models import ( BlockDetailsResponse, BlockOutputResponse, ErrorResponse, InputValidationErrorResponse, ) -from backend.copilot.tools.run_block import RunBlockTool - -from ._test_data import make_session +from .run_block import RunBlockTool _TEST_USER_ID = "test-user-run-block" diff --git a/autogpt_platform/backend/backend/copilot/tools/sandbox.py b/autogpt_platform/backend/backend/copilot/tools/sandbox.py new file mode 100644 index 0000000000..beb326f909 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/sandbox.py @@ -0,0 +1,265 @@ +"""Sandbox execution utilities for code execution tools. + +Provides filesystem + network isolated command execution using **bubblewrap** +(``bwrap``): whitelist-only filesystem (only system dirs visible read-only), +writable workspace only, clean environment, network blocked. 
+ +Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox` +and refuse to run if bubblewrap is not available. +""" + +import asyncio +import logging +import os +import platform +import shutil + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 30 +_MAX_TIMEOUT = 120 + + +# --------------------------------------------------------------------------- +# Sandbox capability detection (cached at first call) +# --------------------------------------------------------------------------- + +_BWRAP_AVAILABLE: bool | None = None + + +def has_full_sandbox() -> bool: + """Return True if bubblewrap is available (filesystem + network isolation). + + On non-Linux platforms (macOS), always returns False. + """ + global _BWRAP_AVAILABLE + if _BWRAP_AVAILABLE is None: + _BWRAP_AVAILABLE = ( + platform.system() == "Linux" and shutil.which("bwrap") is not None + ) + return _BWRAP_AVAILABLE + + +WORKSPACE_PREFIX = "/tmp/copilot-" + + +def make_session_path(session_id: str) -> str: + """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`. + + Shared by both the SDK working-directory setup and the sandbox tools so + they always resolve to the same directory for a given session. + + Steps: + 1. Strip all characters except ``[A-Za-z0-9-]``. + 2. Construct ``/tmp/copilot-``. + 3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised + sanitizer) to prevent path traversal. + + Raises: + ValueError: If the resulting path escapes the prefix. + """ + import re + + safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) + if not safe_id: + safe_id = "default" + path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}") + if not path.startswith(WORKSPACE_PREFIX): + raise ValueError(f"Session path escaped prefix: {path}") + return path + + +def get_workspace_dir(session_id: str) -> str: + """Get or create the workspace directory for a session. 
+ + Uses :func:`make_session_path` — the same path the SDK uses — so that + bash_exec shares the workspace with the SDK file tools. + """ + workspace = make_session_path(session_id) + os.makedirs(workspace, exist_ok=True) + return workspace + + +# --------------------------------------------------------------------------- +# Bubblewrap command builder +# --------------------------------------------------------------------------- + +# System directories mounted read-only inside the sandbox. +# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible. +_SYSTEM_RO_BINDS = [ + "/usr", # binaries, libraries, Python interpreter + "/etc", # system config: ld.so, locale, passwd, alternatives +] + +# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems. +# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind +# can't create a symlink target, so we detect and use --symlink instead. +# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2. +_COMPAT_PATHS = [ + ("/bin", "usr/bin"), # -> /usr/bin on Debian 13 + ("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13 + ("/lib", "usr/lib"), # -> /usr/lib on Debian 13 + ("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter +] + +# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse. +# Applied via ulimit inside the sandbox before exec'ing the user command. +_RESOURCE_LIMITS = ( + "ulimit -u 64" # max 64 processes (prevents fork bombs) + " -v 524288" # 512 MB virtual memory + " -f 51200" # 50 MB max file size (1024-byte blocks) + " -n 256" # 256 open file descriptors + " 2>/dev/null" +) + + +def _build_bwrap_command( + command: list[str], cwd: str, env: dict[str, str] +) -> list[str]: + """Build a bubblewrap command with strict filesystem + network isolation. + + Security model: + - **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``, + ``/bin``, ``/lib``) are mounted read-only. 
Application code (``/app``), + home directories, ``/var``, ``/opt``, etc. are NOT accessible at all. + - **Writable workspace only**: the per-session workspace is the sole + writable path. + - **Clean environment**: ``--clearenv`` wipes all inherited env vars. + Only the explicitly-passed safe env vars are set inside the sandbox. + - **Network isolation**: ``--unshare-net`` blocks all network access. + - **Resource limits**: ulimit caps on processes (64), memory (512MB), + file size (50MB), and open FDs (256) to prevent fork bombs and abuse. + - **New session**: prevents terminal control escape. + - **Die with parent**: prevents orphaned sandbox processes. + """ + cmd = [ + "bwrap", + # Create a new user namespace so bwrap can set up sandboxing + # inside unprivileged Docker containers (no CAP_SYS_ADMIN needed). + "--unshare-user", + # Wipe all inherited environment variables (API keys, secrets, etc.) + "--clearenv", + ] + + # Set only the safe env vars inside the sandbox + for key, value in env.items(): + cmd.extend(["--setenv", key, value]) + + # System directories: read-only + for path in _SYSTEM_RO_BINDS: + cmd.extend(["--ro-bind", path, path]) + + # Compat paths: use --symlink when host path is a symlink (Debian 13), + # --ro-bind when it's a real directory (older distros). + for path, symlink_target in _COMPAT_PATHS: + if os.path.islink(path): + cmd.extend(["--symlink", symlink_target, path]) + elif os.path.exists(path): + cmd.extend(["--ro-bind", path, path]) + + # Wrap the user command with resource limits: + # sh -c 'ulimit ...; exec "$@"' -- + # `exec "$@"` replaces the shell so there's no extra process overhead, + # and properly handles arguments with spaces. + limited_command = [ + "sh", + "-c", + f'{_RESOURCE_LIMITS}; exec "$@"', + "--", + *command, + ] + + cmd.extend( + [ + # Fresh virtual filesystems + "--dev", + "/dev", + "--proc", + "/proc", + "--tmpfs", + "/tmp", + # Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs. 
+ # (workspace lives under /tmp/copilot-) + "--bind", + cwd, + cwd, + # Isolation + "--unshare-net", + "--die-with-parent", + "--new-session", + "--chdir", + cwd, + "--", + *limited_command, + ] + ) + + return cmd + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def run_sandboxed( + command: list[str], + cwd: str, + timeout: int = _DEFAULT_TIMEOUT, + env: dict[str, str] | None = None, +) -> tuple[str, str, int, bool]: + """Run a command inside a bubblewrap sandbox. + + Callers **must** check :func:`has_full_sandbox` before calling this + function. If bubblewrap is not available, this function raises + :class:`RuntimeError` rather than running unsandboxed. + + Returns: + (stdout, stderr, exit_code, timed_out) + """ + if not has_full_sandbox(): + raise RuntimeError( + "run_sandboxed() requires bubblewrap but bwrap is not available. " + "Callers must check has_full_sandbox() before calling this function." 
+ ) + + timeout = min(max(timeout, 1), _MAX_TIMEOUT) + + safe_env = { + "PATH": "/usr/local/bin:/usr/bin:/bin", + "HOME": cwd, + "TMPDIR": cwd, + "LANG": "en_US.UTF-8", + "PYTHONDONTWRITEBYTECODE": "1", + "PYTHONIOENCODING": "utf-8", + } + if env: + safe_env.update(env) + + full_command = _build_bwrap_command(command, cwd, safe_env) + + try: + proc = await asyncio.create_subprocess_exec( + *full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=safe_env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + return stdout, stderr, proc.returncode or 0, False + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + return "", f"Execution timed out after {timeout}s", -1, True + + except RuntimeError: + raise + except Exception as e: + return "", f"Sandbox error: {e}", -1, False diff --git a/autogpt_platform/backend/backend/copilot/tools/search_docs.py b/autogpt_platform/backend/backend/copilot/tools/search_docs.py index 95b571bdcc..b09fe64a2c 100644 --- a/autogpt_platform/backend/backend/copilot/tools/search_docs.py +++ b/autogpt_platform/backend/backend/copilot/tools/search_docs.py @@ -6,15 +6,16 @@ from typing import Any from prisma.enums import ContentType from backend.copilot.model import ChatSession -from backend.copilot.tools.base import BaseTool -from backend.copilot.tools.models import ( +from backend.data.db_accessors import search + +from .base import BaseTool +from .models import ( DocSearchResult, DocSearchResultsResponse, ErrorResponse, NoResultsResponse, ToolResponseBase, ) -from backend.data.db_accessors import search logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py b/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py index 
c5d3986df2..d06fbb766d 100644 --- a/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py +++ b/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py @@ -5,12 +5,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.blocks._base import BlockType -from backend.copilot.tools.models import BlockDetailsResponse -from backend.copilot.tools.run_block import RunBlockTool from backend.data.model import CredentialsMetaInput from backend.integrations.providers import ProviderName from ._test_data import make_session +from .models import BlockDetailsResponse +from .run_block import RunBlockTool _TEST_USER_ID = "test-user-run-block-details" diff --git a/autogpt_platform/backend/backend/copilot/tools/web_fetch.py b/autogpt_platform/backend/backend/copilot/tools/web_fetch.py new file mode 100644 index 0000000000..78ee2f9fe0 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/web_fetch.py @@ -0,0 +1,148 @@ +"""Web fetch tool — safely retrieve public web page content.""" + +import logging +from typing import Any + +import aiohttp +import html2text + +from backend.copilot.model import ChatSession +from backend.util.request import Requests + +from .base import BaseTool +from .models import ErrorResponse, ToolResponseBase, WebFetchResponse + +logger = logging.getLogger(__name__) + +# Limits +_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap +_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15) + +# Content types we'll read as text +_TEXT_CONTENT_TYPES = { + "text/html", + "text/plain", + "text/xml", + "text/csv", + "text/markdown", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/rss+xml", + "application/atom+xml", +} + + +def _is_text_content(content_type: str) -> bool: + base = content_type.split(";")[0].strip().lower() + return base in _TEXT_CONTENT_TYPES or base.startswith("text/") + + +def _html_to_text(html: str) -> str: + h = html2text.HTML2Text() + 
h.ignore_links = False + h.ignore_images = True + h.body_width = 0 + return h.handle(html) + + +class WebFetchTool(BaseTool): + """Safely fetch content from a public URL using SSRF-protected HTTP.""" + + @property + def name(self) -> str: + return "web_fetch" + + @property + def description(self) -> str: + return ( + "Fetch the content of a public web page by URL. " + "Returns readable text extracted from HTML by default. " + "Useful for reading documentation, articles, and API responses. " + "Only supports HTTP/HTTPS GET requests to public URLs " + "(private/internal network addresses are blocked)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The public HTTP/HTTPS URL to fetch.", + }, + "extract_text": { + "type": "boolean", + "description": ( + "If true (default), extract readable text from HTML. " + "If false, return raw content." + ), + "default": True, + }, + }, + "required": ["url"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + url: str = (kwargs.get("url") or "").strip() + extract_text: bool = kwargs.get("extract_text", True) + session_id = session.session_id if session else None + + if not url: + return ErrorResponse( + message="Please provide a URL to fetch.", + error="missing_url", + session_id=session_id, + ) + + try: + client = Requests(raise_for_status=False, retry_max_attempts=1) + response = await client.get(url, timeout=_REQUEST_TIMEOUT) + except ValueError as e: + # validate_url raises ValueError for SSRF / blocked IPs + return ErrorResponse( + message=f"URL blocked: {e}", + error="url_blocked", + session_id=session_id, + ) + except Exception as e: + logger.warning(f"[web_fetch] Request failed for {url}: {e}") + return ErrorResponse( + message=f"Failed to fetch URL: {e}", + 
error="fetch_failed", + session_id=session_id, + ) + + content_type = response.headers.get("content-type", "") + if not _is_text_content(content_type): + return ErrorResponse( + message=f"Non-text content type: {content_type.split(';')[0]}", + error="unsupported_content_type", + session_id=session_id, + ) + + raw = response.content[:_MAX_CONTENT_BYTES] + text = raw.decode("utf-8", errors="replace") + + if extract_text and "html" in content_type.lower(): + text = _html_to_text(text) + + return WebFetchResponse( + message=f"Fetched {url}", + url=response.url, + status_code=response.status, + content_type=content_type.split(";")[0].strip(), + content=text, + truncated=False, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/copilot/tools/workspace_files.py b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py index 9ecbf74052..9960fe1a74 100644 --- a/autogpt_platform/backend/backend/copilot/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py @@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool): @property def description(self) -> str: return ( - "List files in the user's workspace. " + "List files in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read/Glob tools instead. " "Returns file names, paths, sizes, and metadata. " "Optionally filter by path prefix." ) @@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Read a file from the user's workspace. " + "Read a file from the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read tool instead. " "Specify either file_id or path to identify the file. " "For small text files, returns content directly. " "For large or binary files, returns metadata and a download URL. 
" @@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Write or create a file in the user's workspace. " + "Write or create a file in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Write tool instead. " "Provide the content as a base64-encoded string. " f"Maximum file size is {Config().max_file_size_mb}MB. " "Files are saved to the current session's folder by default. " @@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Delete a file from the user's workspace. " + "Delete a file from the user's persistent workspace (cloud storage). " "Specify either file_id or path to identify the file. " "Paths are scoped to the current session by default. " "Use /sessions//... for cross-session access." diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index fbd3573112..4eadc41333 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -38,6 +38,7 @@ class Flag(str, Enum): AGENT_ACTIVITY = "agent-activity" ENABLE_PLATFORM_PAYMENT = "enable-platform-payment" CHAT = "chat" + COPILOT_SDK = "copilot-sdk" def is_configured() -> bool: diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index fcf1f63878..8817232df6 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -669,6 +669,17 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): mem0_api_key: str = Field(default="", description="Mem0 API key") elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key") + linear_api_key: str = Field( + default="", description="Linear API key for system-level operations" + ) + 
linear_feature_request_project_id: str = Field( + default="", + description="Linear project ID where feature requests are tracked", + ) + linear_feature_request_team_id: str = Field( + default="", + description="Linear team ID used when creating feature request issues", + ) linear_client_id: str = Field(default="", description="Linear client ID") linear_client_secret: str = Field(default="", description="Linear client secret") diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index d71cca7865..8062457a70 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -897,6 +897,29 @@ files = [ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.35" +description = "Python SDK for Claude Code" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"}, + {file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"}, + {file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"}, +] + +[package.dependencies] +anyio = ">=4.0.0" +mcp = ">=0.1.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov 
(>=4.0.0)", "ruff (>=0.1.0)"] + [[package]] name = "cleo" version = "2.1.0" @@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.3" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, + {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, +] + [[package]] name = "huggingface-hub" version = "1.4.1" @@ -3310,6 +3345,39 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.26.0" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"}, + {file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27.1" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.11.0,<3.0.0" +pydantic-settings = ">=2.5.2" +pyjwt = {version = ">=2.10.1", extras = ["crypto"]} +python-multipart = ">=0.0.9" +pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""} +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +typing-extensions = ">=4.9.0" +typing-inspection = ">=0.4.1" +uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -5994,7 +6062,7 @@ description = "Python for Window 
Extensions" optional = false python-versions = "*" groups = ["main"] -markers = "platform_system == \"Windows\"" +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, @@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "3.2.0" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"}, + {file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"}, +] + +[package.dependencies] +anyio = ">=4.7.0" +starlette = ">=0.49.1" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stagehand" version = "0.5.9" @@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f" +content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index e5f078b500..6467d15f49 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -16,6 +16,7 @@ anthropic = "^0.79.0" apscheduler = "^3.11.1" 
autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } +claude-agent-sdk = "^0.1.0" click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2" diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index 71ade81a9f..c118057963 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -15,11 +15,16 @@ import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai"; import { useEffect, useRef, useState } from "react"; import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; import { RunBlockTool } from "../../tools/RunBlock/RunBlock"; import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs"; +import { GenericTool } from "../../tools/GenericTool/GenericTool"; import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput"; // --------------------------------------------------------------------------- @@ -254,7 +259,31 @@ export const ChatMessagesContainer = ({ part={part as ToolUIPart} /> ); + case "tool-search_feature_requests": + return ( + + ); + case "tool-create_feature_request": + return ( + + ); default: + // Render a generic tool indicator for SDK built-in + // tools (Read, Glob, Grep, etc.) 
or any unrecognized tool + if (part.type.startsWith("tool-")) { + return ( + + ); + } return null; } })} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx index 6030665f1c..8a35f939ca 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx @@ -14,6 +14,10 @@ import { Text } from "@/components/atoms/Text/Text"; import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider"; import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../tools/RunAgent/RunAgent"; @@ -45,6 +49,8 @@ const SECTIONS = [ "Tool: Create Agent", "Tool: Edit Agent", "Tool: View Agent Output", + "Tool: Search Feature Requests", + "Tool: Create Feature Request", "Full Conversation Example", ] as const; @@ -1421,6 +1427,235 @@ export default function StyleguidePage() { + {/* ============================================================= */} + {/* SEARCH FEATURE REQUESTS */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ + {/* ============================================================= */} + {/* CREATE FEATURE REQUEST */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ {/* ============================================================= */} {/* FULL CONVERSATION EXAMPLE */} {/* ============================================================= */} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx new file mode 100644 index 0000000000..fcd4624b6a --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx @@ -0,0 +1,227 @@ +"use client"; + +import type { ToolUIPart } from "ai"; +import { useMemo } from "react"; + +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardHeader, + ContentCardTitle, + ContentGrid, + ContentMessage, + ContentSuggestionsList, +} from "../../components/ToolAccordion/AccordionContent"; +import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { + AccordionIcon, + getAccordionTitle, + getAnimationText, + getFeatureRequestOutput, + isCreatedOutput, + isErrorOutput, + isNoResultsOutput, + isSearchResultsOutput, + ToolIcon, + type FeatureRequestToolType, +} from "./helpers"; + +export interface FeatureRequestToolPart { + type: FeatureRequestToolType; + toolCallId: string; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +interface Props { + part: FeatureRequestToolPart; +} + +function truncate(text: string, maxChars: number): string { + const trimmed = text.trim(); + if (trimmed.length <= maxChars) return trimmed; + return `${trimmed.slice(0, maxChars).trimEnd()}…`; +} + +export function SearchFeatureRequestsTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + 
part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const searchOutput = + isOutputAvailable && output && isSearchResultsOutput(output) + ? output + : null; + const noResultsOutput = + isOutputAvailable && output && isNoResultsOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && + ((!!searchOutput && searchOutput.count > 0) || + !!noResultsOutput || + !!errorOutput); + + const accordionDescription = + hasExpandableContent && searchOutput + ? `Found ${searchOutput.count} result${searchOutput.count === 1 ? "" : "s"} for "${searchOutput.query}"` + : hasExpandableContent && (noResultsOutput || errorOutput) + ? ((noResultsOutput ?? errorOutput)?.message ?? null) + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {searchOutput && ( + + {searchOutput.results.map((r) => ( + + + {r.title} + + {r.description && ( + + {truncate(r.description, 200)} + + )} + + ))} + + )} + + {noResultsOutput && ( +
+ {noResultsOutput.message} + {noResultsOutput.suggestions && + noResultsOutput.suggestions.length > 0 && ( + + )} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+ ); +} + +export function CreateFeatureRequestTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const createdOutput = + isOutputAvailable && output && isCreatedOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && (!!createdOutput || !!errorOutput); + + const accordionDescription = + hasExpandableContent && createdOutput + ? createdOutput.issue_title + : hasExpandableContent && errorOutput + ? errorOutput.message + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {createdOutput && ( + + + {createdOutput.issue_title} + +
+ + {createdOutput.is_new_issue ? "New" : "Existing"} + +
+ {createdOutput.message} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx new file mode 100644 index 0000000000..75133905b1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx @@ -0,0 +1,271 @@ +import { + CheckCircleIcon, + LightbulbIcon, + MagnifyingGlassIcon, + PlusCircleIcon, +} from "@phosphor-icons/react"; +import type { ToolUIPart } from "ai"; + +/* ------------------------------------------------------------------ */ +/* Types (local until API client is regenerated) */ +/* ------------------------------------------------------------------ */ + +interface FeatureRequestInfo { + id: string; + identifier: string; + title: string; + description?: string | null; +} + +export interface FeatureRequestSearchResponse { + type: "feature_request_search"; + message: string; + results: FeatureRequestInfo[]; + count: number; + query: string; +} + +export interface FeatureRequestCreatedResponse { + type: "feature_request_created"; + message: string; + issue_id: string; + issue_identifier: string; + issue_title: string; + issue_url: string; + is_new_issue: boolean; + customer_name: string; +} + +interface NoResultsResponse { + type: "no_results"; + message: string; + suggestions?: string[]; +} + +interface ErrorResponse { + type: "error"; + message: string; + error?: string; +} + +export type FeatureRequestOutput = + | FeatureRequestSearchResponse + | FeatureRequestCreatedResponse + | NoResultsResponse + | ErrorResponse; + +export type FeatureRequestToolType = + | "tool-search_feature_requests" + | "tool-create_feature_request" + | string; + +/* ------------------------------------------------------------------ */ +/* Output parsing */ +/* ------------------------------------------------------------------ */ + +function parseOutput(output: unknown): FeatureRequestOutput | null { + if (!output) 
return null; + if (typeof output === "string") { + const trimmed = output.trim(); + if (!trimmed) return null; + try { + return parseOutput(JSON.parse(trimmed) as unknown); + } catch { + return null; + } + } + if (typeof output === "object") { + const type = (output as { type?: unknown }).type; + if ( + type === "feature_request_search" || + type === "feature_request_created" || + type === "no_results" || + type === "error" + ) { + return output as FeatureRequestOutput; + } + // Fallback structural checks + if ("results" in output && "query" in output) + return output as FeatureRequestSearchResponse; + if ("issue_identifier" in output) + return output as FeatureRequestCreatedResponse; + if ("suggestions" in output && !("error" in output)) + return output as NoResultsResponse; + if ("error" in output || "details" in output) + return output as ErrorResponse; + } + return null; +} + +export function getFeatureRequestOutput( + part: unknown, +): FeatureRequestOutput | null { + if (!part || typeof part !== "object") return null; + return parseOutput((part as { output?: unknown }).output); +} + +/* ------------------------------------------------------------------ */ +/* Type guards */ +/* ------------------------------------------------------------------ */ + +export function isSearchResultsOutput( + output: FeatureRequestOutput, +): output is FeatureRequestSearchResponse { + return ( + output.type === "feature_request_search" || + ("results" in output && "query" in output) + ); +} + +export function isCreatedOutput( + output: FeatureRequestOutput, +): output is FeatureRequestCreatedResponse { + return ( + output.type === "feature_request_created" || "issue_identifier" in output + ); +} + +export function isNoResultsOutput( + output: FeatureRequestOutput, +): output is NoResultsResponse { + return ( + output.type === "no_results" || + ("suggestions" in output && !("error" in output)) + ); +} + +export function isErrorOutput( + output: FeatureRequestOutput, +): output is 
ErrorResponse { + return output.type === "error" || "error" in output; +} + +/* ------------------------------------------------------------------ */ +/* Accordion metadata */ +/* ------------------------------------------------------------------ */ + +export function getAccordionTitle( + toolType: FeatureRequestToolType, + output: FeatureRequestOutput, +): string { + if (toolType === "tool-search_feature_requests") { + if (isSearchResultsOutput(output)) return "Feature requests"; + if (isNoResultsOutput(output)) return "No feature requests found"; + return "Feature request search error"; + } + if (isCreatedOutput(output)) { + return output.is_new_issue + ? "Feature request created" + : "Added to feature request"; + } + if (isErrorOutput(output)) return "Feature request error"; + return "Feature request"; +} + +/* ------------------------------------------------------------------ */ +/* Animation text */ +/* ------------------------------------------------------------------ */ + +interface AnimationPart { + type: FeatureRequestToolType; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +export function getAnimationText(part: AnimationPart): string { + if (part.type === "tool-search_feature_requests") { + const query = (part.input as { query?: string } | undefined)?.query?.trim(); + const queryText = query ? ` for "${query}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Searching feature requests${queryText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Searching feature requests${queryText}`; + if (isSearchResultsOutput(output)) { + return `Found ${output.count} feature request${output.count === 1 ? 
"" : "s"}${queryText}`; + } + if (isNoResultsOutput(output)) + return `No feature requests found${queryText}`; + return `Error searching feature requests${queryText}`; + } + case "output-error": + return `Error searching feature requests${queryText}`; + default: + return "Searching feature requests"; + } + } + + // create_feature_request + const title = (part.input as { title?: string } | undefined)?.title?.trim(); + const titleText = title ? ` "${title}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Creating feature request${titleText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Creating feature request${titleText}`; + if (isCreatedOutput(output)) { + return output.is_new_issue + ? "Feature request created" + : "Added to existing feature request"; + } + if (isErrorOutput(output)) return "Error creating feature request"; + return `Created feature request${titleText}`; + } + case "output-error": + return "Error creating feature request"; + default: + return "Creating feature request"; + } +} + +/* ------------------------------------------------------------------ */ +/* Icons */ +/* ------------------------------------------------------------------ */ + +export function ToolIcon({ + toolType, + isStreaming, + isError, +}: { + toolType: FeatureRequestToolType; + isStreaming?: boolean; + isError?: boolean; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? PlusCircleIcon + : MagnifyingGlassIcon; + + return ( + + ); +} + +export function AccordionIcon({ + toolType, +}: { + toolType: FeatureRequestToolType; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? 
CheckCircleIcon + : LightbulbIcon; + return ; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx new file mode 100644 index 0000000000..677f1d01d1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx @@ -0,0 +1,63 @@ +"use client"; + +import { ToolUIPart } from "ai"; +import { GearIcon } from "@phosphor-icons/react"; +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; + +interface Props { + part: ToolUIPart; +} + +function extractToolName(part: ToolUIPart): string { + // ToolUIPart.type is "tool-{name}", extract the name portion. + return part.type.replace(/^tool-/, ""); +} + +function formatToolName(name: string): string { + // "search_docs" → "Search docs", "Read" → "Read" + return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase()); +} + +function getAnimationText(part: ToolUIPart): string { + const label = formatToolName(extractToolName(part)); + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Running ${label}…`; + case "output-available": + return `${label} completed`; + case "output-error": + return `${label} failed`; + default: + return `Running ${label}…`; + } +} + +export function GenericTool({ part }: Props) { + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = part.state === "output-error"; + + return ( +
+
+ + +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 496a714ba5..8e48931540 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -7066,13 +7066,57 @@ "properties": { "id": { "type": "string", "title": "Id" }, "name": { "type": "string", "title": "Name" }, - "description": { "type": "string", "title": "Description" } + "description": { "type": "string", "title": "Description" }, + "categories": { + "items": { "type": "string" }, + "type": "array", + "title": "Categories" + }, + "input_schema": { + "additionalProperties": true, + "type": "object", + "title": "Input Schema", + "description": "Full JSON schema for block inputs" + }, + "output_schema": { + "additionalProperties": true, + "type": "object", + "title": "Output Schema", + "description": "Full JSON schema for block outputs" + }, + "required_inputs": { + "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" }, + "type": "array", + "title": "Required Inputs", + "description": "List of input fields for this block" + } }, "type": "object", - "required": ["id", "name", "description"], + "required": ["id", "name", "description", "categories"], "title": "BlockInfoSummary", "description": "Summary of a block for search results." }, + "BlockInputFieldInfo": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "type": { "type": "string", "title": "Type" }, + "description": { + "type": "string", + "title": "Description", + "default": "" + }, + "required": { + "type": "boolean", + "title": "Required", + "default": false + }, + "default": { "anyOf": [{}, { "type": "null" }], "title": "Default" } + }, + "type": "object", + "required": ["name", "type"], + "title": "BlockInputFieldInfo", + "description": "Information about a block input field." 
+ }, "BlockListResponse": { "properties": { "type": { @@ -7090,7 +7134,12 @@ "title": "Blocks" }, "count": { "type": "integer", "title": "Count" }, - "query": { "type": "string", "title": "Query" } + "query": { "type": "string", "title": "Query" }, + "usage_hint": { + "type": "string", + "title": "Usage Hint", + "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs." + } }, "type": "object", "required": ["message", "blocks", "count", "query"], @@ -10495,7 +10544,12 @@ "operation_started", "operation_pending", "operation_in_progress", - "input_validation_error" + "input_validation_error", + "web_fetch", + "bash_exec", + "operation_status", + "feature_request_search", + "feature_request_created" ], "title": "ResponseType", "description": "Types of tool responses." diff --git a/plans/SECRT-1950-claude-ci-optimizations.md b/plans/SECRT-1950-claude-ci-optimizations.md new file mode 100644 index 0000000000..15d1419b0e --- /dev/null +++ b/plans/SECRT-1950-claude-ci-optimizations.md @@ -0,0 +1,165 @@ +# Implementation Plan: SECRT-1950 - Apply E2E CI Optimizations to Claude Code Workflows + +## Ticket +[SECRT-1950](https://linear.app/autogpt/issue/SECRT-1950) + +## Summary +Apply Pwuts's CI performance optimizations from PR #12090 to Claude Code workflows. 
+ +## Reference PR +https://github.com/Significant-Gravitas/AutoGPT/pull/12090 + +--- + +## Analysis + +### Current State (claude.yml) + +**pnpm caching (lines 104-118):** +```yaml +- name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + +- name: Enable corepack + run: corepack enable + +- name: Set pnpm store directory + run: | + pnpm config set store-dir ~/.pnpm-store + echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV + +- name: Cache frontend dependencies + uses: actions/cache@v5 + with: + path: ~/.pnpm-store + key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} + restore-keys: | + ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} + ${{ runner.os }}-pnpm- +``` + +**Docker setup (lines 134-165):** +- Uses `docker-buildx-action@v3` +- Has manual Docker image caching via `actions/cache` +- Runs `docker compose up` without buildx bake optimization + +### Pwuts's Optimizations (PR #12090) + +1. **Simplified pnpm caching** - Use `setup-node` built-in cache: +```yaml +- name: Enable corepack + run: corepack enable + +- name: Set up Node + uses: actions/setup-node@v6 + with: + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml +``` + +2. **Docker build caching via buildx bake**: +```yaml +- name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: network=host + +- name: Expose GHA cache to docker buildx CLI + uses: crazy-max/ghaction-github-runtime@v3 + +- name: Build Docker images (with cache) + run: | + pip install pyyaml + docker compose -f docker-compose.yml config > docker-compose.resolved.yml + python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ + --source docker-compose.resolved.yml \ + --cache-from "type=gha" \ + --cache-to "type=gha,mode=max" \ + ... + docker buildx bake --allow=fs.read=.. 
-f docker-compose.resolved.yml --load +``` + +--- + +## Proposed Changes + +### 1. Update pnpm caching in `claude.yml` + +**Before:** +- Manual cache key generation +- Separate `actions/cache` step +- Manual pnpm store directory config + +**After:** +- Use `setup-node` built-in `cache: "pnpm"` option +- Remove manual cache step +- Keep `corepack enable` before `setup-node` + +### 2. Update Docker build in `claude.yml` + +**Before:** +- Manual Docker layer caching via `actions/cache` with `/tmp/.buildx-cache` +- Simple `docker compose build` + +**After:** +- Use `crazy-max/ghaction-github-runtime@v3` to expose GHA cache +- Use `docker-ci-fix-compose-build-cache.py` script +- Build with `docker buildx bake` + +### 3. Apply same changes to other Claude workflows + +- `claude-dependabot.yml` - Check if it has similar patterns +- `claude-ci-failure-auto-fix.yml` - Check if it has similar patterns +- `copilot-setup-steps.yml` - Reusable workflow, may be the source of truth + +--- + +## Files to Modify + +1. `.github/workflows/claude.yml` +2. `.github/workflows/claude-dependabot.yml` (if applicable) +3. `.github/workflows/claude-ci-failure-auto-fix.yml` (if applicable) + +## Dependencies + +- PR #12090 must be merged first (provides the `docker-ci-fix-compose-build-cache.py` script) +- Backend Dockerfile optimizations (already in PR #12090) + +--- + +## Test Plan + +1. Create PR with changes +2. Trigger Claude workflow manually or via `@claude` mention on a test issue +3. Compare CI runtime before/after +4. Verify Claude agent still works correctly (can checkout, build, run tests) + +--- + +## Risk Assessment + +**Low risk:** +- These are CI infrastructure changes, not code changes +- If caching fails, builds fall back to uncached (slower but works) +- Changes mirror proven patterns from PR #12090 + +--- + +## Questions for Reviewer + +1. Should we wait for PR #12090 to merge before creating this PR? +2. 
Does `copilot-setup-steps.yml` need updating, or is it a separate concern? +3. Any concerns about cache key collisions between frontend E2E and Claude workflows? + +--- + +## Verified + +- ✅ **`claude-dependabot.yml`**: Has same pnpm caching pattern as `claude.yml` (manual `actions/cache`) — NEEDS UPDATE +- ✅ **`claude-ci-failure-auto-fix.yml`**: Simple workflow with no pnpm or Docker caching — NO CHANGES NEEDED +- ✅ **Script path**: `docker-ci-fix-compose-build-cache.py` will be at `.github/workflows/scripts/` after PR #12090 merges +- ✅ **Test seed caching**: NOT APPLICABLE — Claude workflows spin up a dev environment but don't run E2E tests with pre-seeded data. The seed caching in PR #12090 is specific to the frontend E2E test suite which needs consistent test data. Claude just needs the services running.