mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9314052e89 | |||
| 08890f189f | |||
| 180a35f013 | |||
| 18365e0323 | |||
| 9a743ff51a | |||
| 29577935b4 | |||
| 7498353ed5 | |||
| b62bdfd143 | |||
| fb98faf4ac | |||
| a8f62aa30c | |||
| 1a7449b03a | |||
| 1091901be2 | |||
| 15160f6733 | |||
| 13dba59bb8 | |||
| 478c998f04 | |||
| a9fc93ffbf | |||
| cc100c0d10 | |||
| 7bc3300981 | |||
| 3e0283796e | |||
| cd0175d83e | |||
| f313cfceb9 | |||
| fb0108f946 | |||
| 6b29a82de3 |
@@ -1,433 +0,0 @@
|
||||
name: Auto-Fix Tagged Issue with OpenHands
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
max_iterations:
|
||||
required: false
|
||||
type: number
|
||||
default: 50
|
||||
macro:
|
||||
required: false
|
||||
type: string
|
||||
default: "@openhands-agent"
|
||||
target_branch:
|
||||
required: false
|
||||
type: string
|
||||
default: "main"
|
||||
description: "Target branch to pull and create PR against"
|
||||
pr_type:
|
||||
required: false
|
||||
type: string
|
||||
default: "draft"
|
||||
description: "The PR type that is going to be created (draft, ready)"
|
||||
LLM_MODEL:
|
||||
required: false
|
||||
type: string
|
||||
default: "anthropic/claude-sonnet-4-20250514"
|
||||
LLM_API_VERSION:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
base_container_image:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
description: "Custom sandbox env"
|
||||
runner:
|
||||
required: false
|
||||
type: string
|
||||
default: "ubuntu-latest"
|
||||
secrets:
|
||||
LLM_MODEL:
|
||||
required: false
|
||||
LLM_API_KEY:
|
||||
required: true
|
||||
LLM_BASE_URL:
|
||||
required: false
|
||||
PAT_TOKEN:
|
||||
required: false
|
||||
PAT_USERNAME:
|
||||
required: false
|
||||
|
||||
issues:
|
||||
types: [labeled]
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
pull_request_review:
|
||||
types: [submitted]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
auto-fix:
|
||||
if: |
|
||||
github.event_name == 'workflow_call' ||
|
||||
github.event.label.name == 'fix-me' ||
|
||||
github.event.label.name == 'fix-me-experimental' ||
|
||||
(
|
||||
((github.event_name == 'issue_comment' || github.event_name == 'pull_request_review_comment') &&
|
||||
contains(github.event.comment.body, inputs.macro || '@openhands-agent') &&
|
||||
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'COLLABORATOR' || github.event.comment.author_association == 'MEMBER')
|
||||
) ||
|
||||
|
||||
(github.event_name == 'pull_request_review' &&
|
||||
contains(github.event.review.body, inputs.macro || '@openhands-agent') &&
|
||||
(github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'COLLABORATOR' || github.event.review.author_association == 'MEMBER')
|
||||
)
|
||||
)
|
||||
runs-on: "${{ inputs.runner || 'ubuntu-latest' }}"
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- name: Upgrade pip
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Get latest versions and create requirements.txt
|
||||
run: |
|
||||
python -m pip index versions openhands-ai > openhands_versions.txt
|
||||
OPENHANDS_VERSION=$(head -n 1 openhands_versions.txt | awk '{print $2}' | tr -d '()')
|
||||
|
||||
# Create a new requirements.txt locally within the workflow, ensuring no reference to the repo's file
|
||||
echo "openhands-ai==${OPENHANDS_VERSION}" > /tmp/requirements.txt
|
||||
cat /tmp/requirements.txt
|
||||
|
||||
- name: Cache pip dependencies
|
||||
if: |
|
||||
!(
|
||||
github.event.label.name == 'fix-me-experimental' ||
|
||||
(
|
||||
(github.event_name == 'issue_comment' || github.event_name == 'pull_request_review_comment') &&
|
||||
contains(github.event.comment.body, '@openhands-agent-exp')
|
||||
) ||
|
||||
(
|
||||
github.event_name == 'pull_request_review' &&
|
||||
contains(github.event.review.body, '@openhands-agent-exp')
|
||||
)
|
||||
)
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ${{ env.pythonLocation }}/lib/python3.12/site-packages/*
|
||||
key: ${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
|
||||
|
||||
- name: Check required environment variables
|
||||
env:
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL || inputs.LLM_MODEL }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
LLM_API_VERSION: ${{ inputs.LLM_API_VERSION }}
|
||||
PAT_TOKEN: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC }}
|
||||
PAT_USERNAME: ${{ secrets.PAT_USERNAME }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
required_vars=("LLM_API_KEY")
|
||||
for var in "${required_vars[@]}"; do
|
||||
if [ -z "${!var}" ]; then
|
||||
echo "Error: Required environment variable $var is not set."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
# Check optional variables and warn about fallbacks
|
||||
if [ -z "$LLM_BASE_URL" ]; then
|
||||
echo "Warning: LLM_BASE_URL is not set, will use default API endpoint"
|
||||
fi
|
||||
|
||||
if [ -z "$PAT_TOKEN" ]; then
|
||||
echo "Warning: PAT_TOKEN is not set, falling back to GITHUB_TOKEN"
|
||||
fi
|
||||
|
||||
if [ -z "$PAT_USERNAME" ]; then
|
||||
echo "Warning: PAT_USERNAME is not set, will use openhands-agent"
|
||||
fi
|
||||
|
||||
- name: Set environment variables
|
||||
env:
|
||||
REVIEW_BODY: ${{ github.event.review.body || '' }}
|
||||
run: |
|
||||
# Handle pull request events first
|
||||
if [ -n "${{ github.event.pull_request.number }}" ]; then
|
||||
echo "ISSUE_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV
|
||||
echo "ISSUE_TYPE=pr" >> $GITHUB_ENV
|
||||
# Handle pull request review events
|
||||
elif [ -n "$REVIEW_BODY" ]; then
|
||||
echo "ISSUE_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV
|
||||
echo "ISSUE_TYPE=pr" >> $GITHUB_ENV
|
||||
# Handle issue comment events that reference a PR
|
||||
elif [ -n "${{ github.event.issue.pull_request }}" ]; then
|
||||
echo "ISSUE_NUMBER=${{ github.event.issue.number }}" >> $GITHUB_ENV
|
||||
echo "ISSUE_TYPE=pr" >> $GITHUB_ENV
|
||||
# Handle regular issue events
|
||||
else
|
||||
echo "ISSUE_NUMBER=${{ github.event.issue.number }}" >> $GITHUB_ENV
|
||||
echo "ISSUE_TYPE=issue" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
if [ -n "$REVIEW_BODY" ]; then
|
||||
echo "COMMENT_ID=${{ github.event.review.id || 'None' }}" >> $GITHUB_ENV
|
||||
else
|
||||
echo "COMMENT_ID=${{ github.event.comment.id || 'None' }}" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
echo "MAX_ITERATIONS=${{ inputs.max_iterations || 50 }}" >> $GITHUB_ENV
|
||||
echo "SANDBOX_ENV_GITHUB_TOKEN=${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}" >> $GITHUB_ENV
|
||||
echo "SANDBOX_BASE_CONTAINER_IMAGE=${{ inputs.base_container_image }}" >> $GITHUB_ENV
|
||||
|
||||
# Set branch variables
|
||||
echo "TARGET_BRANCH=${{ inputs.target_branch || 'main' }}" >> $GITHUB_ENV
|
||||
|
||||
- name: Comment on issue with start message
|
||||
uses: actions/github-script@v9
|
||||
with:
|
||||
github-token: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
script: |
|
||||
const issueType = process.env.ISSUE_TYPE;
|
||||
github.rest.issues.createComment({
|
||||
issue_number: ${{ env.ISSUE_NUMBER }},
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: `[OpenHands](https://github.com/OpenHands/OpenHands) started fixing the ${issueType}! You can monitor the progress [here](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}).`
|
||||
});
|
||||
|
||||
- name: Install OpenHands
|
||||
id: install_openhands
|
||||
uses: actions/github-script@v9
|
||||
env:
|
||||
COMMENT_BODY: ${{ github.event.comment.body || '' }}
|
||||
REVIEW_BODY: ${{ github.event.review.body || '' }}
|
||||
LABEL_NAME: ${{ github.event.label.name || '' }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
with:
|
||||
script: |
|
||||
const commentBody = process.env.COMMENT_BODY.trim();
|
||||
const reviewBody = process.env.REVIEW_BODY.trim();
|
||||
const labelName = process.env.LABEL_NAME.trim();
|
||||
const eventName = process.env.EVENT_NAME.trim();
|
||||
// Check conditions
|
||||
const isExperimentalLabel = labelName === "fix-me-experimental";
|
||||
const isIssueCommentExperimental =
|
||||
(eventName === "issue_comment" || eventName === "pull_request_review_comment") &&
|
||||
commentBody.includes("@openhands-agent-exp");
|
||||
const isReviewCommentExperimental =
|
||||
eventName === "pull_request_review" && reviewBody.includes("@openhands-agent-exp");
|
||||
|
||||
// Set output variable
|
||||
core.setOutput('isExperimental', isExperimentalLabel || isIssueCommentExperimental || isReviewCommentExperimental);
|
||||
|
||||
// Perform package installation
|
||||
if (isExperimentalLabel || isIssueCommentExperimental || isReviewCommentExperimental) {
|
||||
console.log("Installing experimental OpenHands...");
|
||||
|
||||
await exec.exec("pip install git+https://github.com/openhands/openhands.git");
|
||||
} else {
|
||||
console.log("Installing from requirements.txt...");
|
||||
|
||||
await exec.exec("pip install -r /tmp/requirements.txt");
|
||||
}
|
||||
|
||||
- name: Attempt to resolve issue
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
GITHUB_USERNAME: ${{ secrets.PAT_USERNAME || 'openhands-agent' }}
|
||||
GIT_USERNAME: ${{ secrets.PAT_USERNAME || 'openhands-agent' }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL || inputs.LLM_MODEL }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
LLM_API_VERSION: ${{ inputs.LLM_API_VERSION }}
|
||||
PYTHONPATH: ""
|
||||
run: |
|
||||
cd /tmp && python -m openhands.resolver.resolve_issue \
|
||||
--selected-repo ${{ github.repository }} \
|
||||
--issue-number ${{ env.ISSUE_NUMBER }} \
|
||||
--issue-type ${{ env.ISSUE_TYPE }} \
|
||||
--max-iterations ${{ env.MAX_ITERATIONS }} \
|
||||
--comment-id ${{ env.COMMENT_ID }} \
|
||||
--is-experimental ${{ steps.install_openhands.outputs.isExperimental }}
|
||||
|
||||
- name: Check resolution result
|
||||
id: check_result
|
||||
run: |
|
||||
if cd /tmp && grep -q '"success":true' output/output.jsonl; then
|
||||
echo "RESOLUTION_SUCCESS=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "RESOLUTION_SUCCESS=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Upload output.jsonl as artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
if: always() # Upload even if the previous steps fail
|
||||
with:
|
||||
name: resolver-output
|
||||
path: /tmp/output/output.jsonl
|
||||
retention-days: 30 # Keep the artifact for 30 days
|
||||
|
||||
- name: Create draft PR or push branch
|
||||
if: always() # Create PR or branch even if the previous steps fail
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
GITHUB_USERNAME: ${{ secrets.PAT_USERNAME || 'openhands-agent' }}
|
||||
GIT_USERNAME: ${{ secrets.PAT_USERNAME || 'openhands-agent' }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL || inputs.LLM_MODEL }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
|
||||
LLM_API_VERSION: ${{ inputs.LLM_API_VERSION }}
|
||||
PYTHONPATH: ""
|
||||
run: |
|
||||
if [ "${{ steps.check_result.outputs.RESOLUTION_SUCCESS }}" == "true" ]; then
|
||||
cd /tmp && python -m openhands.resolver.send_pull_request \
|
||||
--issue-number ${{ env.ISSUE_NUMBER }} \
|
||||
--target-branch ${{ env.TARGET_BRANCH }} \
|
||||
--pr-type ${{ inputs.pr_type || 'draft' }} \
|
||||
--reviewer ${{ github.actor }} | tee pr_result.txt && \
|
||||
grep "PR created" pr_result.txt | sed 's/.*\///g' > pr_number.txt
|
||||
else
|
||||
cd /tmp && python -m openhands.resolver.send_pull_request \
|
||||
--issue-number ${{ env.ISSUE_NUMBER }} \
|
||||
--pr-type branch \
|
||||
--send-on-failure | tee branch_result.txt && \
|
||||
grep "branch created" branch_result.txt | sed 's/.*\///g; s/.expand=1//g' > branch_name.txt
|
||||
fi
|
||||
|
||||
# Step leaves comment for when agent is invoked on PR
|
||||
- name: Analyze Push Logs (Updated PR or No Changes) # Skip comment if PR update was successful OR leave comment if the agent made no code changes
|
||||
uses: actions/github-script@v9
|
||||
if: always()
|
||||
env:
|
||||
AGENT_RESPONDED: ${{ env.AGENT_RESPONDED || 'false' }}
|
||||
ISSUE_NUMBER: ${{ env.ISSUE_NUMBER }}
|
||||
with:
|
||||
github-token: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const issueNumber = process.env.ISSUE_NUMBER;
|
||||
let logContent = '';
|
||||
|
||||
try {
|
||||
logContent = fs.readFileSync('/tmp/pr_result.txt', 'utf8').trim();
|
||||
} catch (error) {
|
||||
console.error('Error reading pr_result.txt file:', error);
|
||||
}
|
||||
|
||||
const noChangesMessage = `No changes to commit for issue #${issueNumber}. Skipping commit.`;
|
||||
|
||||
// Check logs from send_pull_request.py (pushes code to GitHub)
|
||||
if (logContent.includes("Updated pull request")) {
|
||||
console.log("Updated pull request found. Skipping comment.");
|
||||
process.env.AGENT_RESPONDED = 'true';
|
||||
} else if (logContent.includes(noChangesMessage)) {
|
||||
github.rest.issues.createComment({
|
||||
issue_number: issueNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: `The workflow to fix this issue encountered an error. Openhands failed to create any code changes.`
|
||||
});
|
||||
process.env.AGENT_RESPONDED = 'true';
|
||||
}
|
||||
|
||||
# Step leaves comment for when agent is invoked on issue
|
||||
- name: Comment on issue # Comment link to either PR or branch created by agent
|
||||
uses: actions/github-script@v9
|
||||
if: always() # Comment on issue even if the previous steps fail
|
||||
env:
|
||||
AGENT_RESPONDED: ${{ env.AGENT_RESPONDED || 'false' }}
|
||||
ISSUE_NUMBER: ${{ env.ISSUE_NUMBER }}
|
||||
RESOLUTION_SUCCESS: ${{ steps.check_result.outputs.RESOLUTION_SUCCESS }}
|
||||
with:
|
||||
github-token: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const issueNumber = process.env.ISSUE_NUMBER;
|
||||
const success = process.env.RESOLUTION_SUCCESS === 'true';
|
||||
|
||||
let prNumber = '';
|
||||
let branchName = '';
|
||||
let resultExplanation = '';
|
||||
|
||||
try {
|
||||
if (success) {
|
||||
prNumber = fs.readFileSync('/tmp/pr_number.txt', 'utf8').trim();
|
||||
} else {
|
||||
branchName = fs.readFileSync('/tmp/branch_name.txt', 'utf8').trim();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error reading file:', error);
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
if (!success){
|
||||
// Read result_explanation from JSON file for failed resolution
|
||||
const outputFilePath = path.resolve('/tmp/output/output.jsonl');
|
||||
if (fs.existsSync(outputFilePath)) {
|
||||
const outputContent = fs.readFileSync(outputFilePath, 'utf8');
|
||||
const jsonLines = outputContent.split('\n').filter(line => line.trim() !== '');
|
||||
|
||||
if (jsonLines.length > 0) {
|
||||
// First entry in JSON lines has the key 'result_explanation'
|
||||
const firstEntry = JSON.parse(jsonLines[0]);
|
||||
resultExplanation = firstEntry.result_explanation || '';
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error){
|
||||
console.error('Error reading file:', error);
|
||||
}
|
||||
|
||||
// Check "success" log from resolver output
|
||||
if (success && prNumber) {
|
||||
github.rest.issues.createComment({
|
||||
issue_number: issueNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: `A potential fix has been generated and a draft PR #${prNumber} has been created. Please review the changes.`
|
||||
});
|
||||
process.env.AGENT_RESPONDED = 'true';
|
||||
} else if (!success && branchName) {
|
||||
let commentBody = `An attempt was made to automatically fix this issue, but it was unsuccessful. A branch named '${branchName}' has been created with the attempted changes. You can view the branch [here](https://github.com/${context.repo.owner}/${context.repo.repo}/tree/${branchName}). Manual intervention may be required.`;
|
||||
|
||||
if (resultExplanation) {
|
||||
commentBody += `\n\nAdditional details about the failure:\n${resultExplanation}`;
|
||||
}
|
||||
|
||||
github.rest.issues.createComment({
|
||||
issue_number: issueNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: commentBody
|
||||
});
|
||||
process.env.AGENT_RESPONDED = 'true';
|
||||
}
|
||||
|
||||
# Leave error comment when both PR/Issue comment handling fail
|
||||
- name: Fallback Error Comment
|
||||
uses: actions/github-script@v9
|
||||
if: ${{ env.AGENT_RESPONDED == 'false' }} # Only run if no conditions were met in previous steps
|
||||
env:
|
||||
ISSUE_NUMBER: ${{ env.ISSUE_NUMBER }}
|
||||
with:
|
||||
github-token: ${{ secrets.OPENHANDS_BOT_GITHUB_PAT_PUBLIC || github.token }}
|
||||
script: |
|
||||
const issueNumber = process.env.ISSUE_NUMBER;
|
||||
|
||||
github.rest.issues.createComment({
|
||||
issue_number: issueNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: `The workflow to fix this issue encountered an error. Please check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for more information.`
|
||||
});
|
||||
@@ -60,6 +60,7 @@ repos:
|
||||
lxml,
|
||||
"openhands-sdk==1.17.0",
|
||||
"openhands-tools==1.17.0",
|
||||
"sqlalchemy>=2.0",
|
||||
]
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file dev_config/python/mypy.ini openhands/
|
||||
|
||||
@@ -12,9 +12,6 @@ disable_error_code = type-abstract
|
||||
# Exclude third-party runtime directory from type checking
|
||||
exclude = (third_party/|enterprise/)
|
||||
|
||||
[mypy-openhands.memory.condenser.impl.*]
|
||||
disable_error_code = override
|
||||
|
||||
[mypy-openai.*]
|
||||
follow_imports = skip
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -50,6 +50,7 @@ repos:
|
||||
- ./
|
||||
- stripe==11.5.0
|
||||
- pygithub==2.6.1
|
||||
- sqlalchemy>=2.0
|
||||
# Use -p (package) to avoid dual module name conflict when using MYPYPATH
|
||||
# MYPYPATH=enterprise allows resolving bare imports like "from integrations.xxx"
|
||||
# Note: tests package excluded to avoid conflict with core openhands tests
|
||||
|
||||
@@ -429,6 +429,11 @@ class GitHubDataCollector:
|
||||
- Num openhands review comments
|
||||
"""
|
||||
pr_number = openhands_pr.pr_number
|
||||
if openhands_pr.installation_id is None:
|
||||
logger.warning(
|
||||
f'Skipping PR {openhands_pr.repo_name}#{pr_number}: missing installation_id'
|
||||
)
|
||||
return
|
||||
installation_id = int(openhands_pr.installation_id)
|
||||
repo_id = openhands_pr.repo_id
|
||||
|
||||
|
||||
@@ -59,11 +59,11 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
# Create the customer in stripe (only include email if available)
|
||||
create_params: dict = {'metadata': {'org_id': str(org.id)}}
|
||||
if org.contact_email:
|
||||
create_params['email'] = org.contact_email
|
||||
customer = await stripe.Customer.create_async(**create_params)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
async with a_session_maker() as session:
|
||||
@@ -108,11 +108,14 @@ async def migrate_customer(session, user_id: str, org: Org):
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
# Only include email if available to avoid sending empty strings to Stripe
|
||||
modify_params: dict = {
|
||||
'id': stripe_customer.stripe_customer_id,
|
||||
'metadata': {'user_id': '', 'org_id': str(org.id)},
|
||||
}
|
||||
if org.contact_email:
|
||||
modify_params['email'] = org.contact_email
|
||||
customer = await stripe.Customer.modify_async(**modify_params)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
|
||||
@@ -49,133 +49,6 @@ def _strip_none_and_empty(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _next_server_name(existing: Mapping[str, Any], base_name: str) -> str:
|
||||
if base_name not in existing:
|
||||
return base_name
|
||||
|
||||
suffix = 1
|
||||
while f'{base_name}_{suffix}' in existing:
|
||||
suffix += 1
|
||||
return f'{base_name}_{suffix}'
|
||||
|
||||
|
||||
def _normalize_mcp_config(value: Any) -> Any:
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
|
||||
raw_mcp_servers = value.get('mcpServers')
|
||||
if isinstance(raw_mcp_servers, Mapping):
|
||||
mcp_servers = dict(raw_mcp_servers)
|
||||
return {'mcpServers': mcp_servers} if mcp_servers else None
|
||||
|
||||
if not any(
|
||||
key in value for key in ('sse_servers', 'stdio_servers', 'shttp_servers')
|
||||
):
|
||||
return value
|
||||
|
||||
servers: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for entry in value.get('sse_servers', []) or []:
|
||||
if isinstance(entry, str):
|
||||
entry = {'url': entry}
|
||||
if not isinstance(entry, Mapping) or not isinstance(entry.get('url'), str):
|
||||
continue
|
||||
|
||||
server: dict[str, Any] = {'url': entry['url'], 'transport': 'sse'}
|
||||
if entry.get('api_key') is not None:
|
||||
server['auth'] = entry.get('api_key')
|
||||
servers[_next_server_name(servers, 'sse')] = server
|
||||
|
||||
for entry in value.get('shttp_servers', []) or []:
|
||||
if isinstance(entry, str):
|
||||
entry = {'url': entry}
|
||||
if not isinstance(entry, Mapping) or not isinstance(entry.get('url'), str):
|
||||
continue
|
||||
|
||||
server = {'url': entry['url']}
|
||||
if entry.get('api_key') is not None:
|
||||
server['auth'] = entry.get('api_key')
|
||||
if entry.get('timeout') is not None:
|
||||
server['timeout'] = entry.get('timeout')
|
||||
servers[_next_server_name(servers, 'shttp')] = server
|
||||
|
||||
for entry in value.get('stdio_servers', []) or []:
|
||||
if not isinstance(entry, Mapping) or not isinstance(entry.get('command'), str):
|
||||
continue
|
||||
|
||||
server = {'command': entry['command']}
|
||||
if entry.get('args') is not None:
|
||||
server['args'] = entry.get('args')
|
||||
if entry.get('env') is not None:
|
||||
server['env'] = entry.get('env')
|
||||
base_name = entry.get('name') if isinstance(entry.get('name'), str) else 'stdio'
|
||||
servers[_next_server_name(servers, base_name)] = server
|
||||
|
||||
return {'mcpServers': servers} if servers else None
|
||||
|
||||
|
||||
def _legacy_api_key(auth_value: Any) -> str | None:
|
||||
if isinstance(auth_value, str) and auth_value != 'oauth':
|
||||
return auth_value
|
||||
return None
|
||||
|
||||
|
||||
def _to_legacy_mcp_config(value: Any) -> Any:
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
|
||||
raw_mcp_servers = value.get('mcpServers')
|
||||
if not isinstance(raw_mcp_servers, Mapping):
|
||||
return value
|
||||
|
||||
legacy: dict[str, list[Any]] = {
|
||||
'sse_servers': [],
|
||||
'stdio_servers': [],
|
||||
'shttp_servers': [],
|
||||
}
|
||||
|
||||
for server_name, server_config in raw_mcp_servers.items():
|
||||
if not isinstance(server_config, Mapping):
|
||||
continue
|
||||
|
||||
url = server_config.get('url')
|
||||
if isinstance(url, str):
|
||||
entry: dict[str, Any] = {'url': url}
|
||||
api_key = _legacy_api_key(server_config.get('auth'))
|
||||
if api_key is not None:
|
||||
entry['api_key'] = api_key
|
||||
if server_config.get('transport') == 'sse':
|
||||
legacy['sse_servers'].append(entry)
|
||||
else:
|
||||
if server_config.get('timeout') is not None:
|
||||
entry['timeout'] = server_config.get('timeout')
|
||||
legacy['shttp_servers'].append(entry)
|
||||
continue
|
||||
|
||||
command = server_config.get('command')
|
||||
if not isinstance(command, str):
|
||||
continue
|
||||
|
||||
entry = {'name': server_name, 'command': command}
|
||||
if server_config.get('args') is not None:
|
||||
entry['args'] = server_config.get('args')
|
||||
if server_config.get('env') is not None:
|
||||
entry['env'] = server_config.get('env')
|
||||
legacy['stdio_servers'].append(entry)
|
||||
|
||||
return legacy
|
||||
|
||||
|
||||
def _normalize_nested_mcp_config(settings: Mapping[str, Any] | None) -> dict[str, Any]:
|
||||
normalized = dict(settings or {})
|
||||
mcp_config = _normalize_mcp_config(normalized.get('mcp_config'))
|
||||
if mcp_config is None:
|
||||
normalized.pop('mcp_config', None)
|
||||
else:
|
||||
normalized['mcp_config'] = mcp_config
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_user_agent_settings(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
generated = _strip_none_and_empty(
|
||||
{
|
||||
@@ -189,14 +62,10 @@ def _build_user_agent_settings(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
'enabled': row['enable_default_condenser'],
|
||||
'max_size': row['condenser_max_size'],
|
||||
},
|
||||
'mcp_config': _normalize_mcp_config(row['mcp_config']),
|
||||
'mcp_config': row['mcp_config'],
|
||||
}
|
||||
)
|
||||
merged = _deep_merge(
|
||||
generated,
|
||||
_normalize_nested_mcp_config(row.get('agent_settings')),
|
||||
)
|
||||
return _normalize_nested_mcp_config(merged)
|
||||
return _deep_merge(generated, row.get('agent_settings') or {})
|
||||
|
||||
|
||||
def _build_user_conversation_settings(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
@@ -218,14 +87,10 @@ def _build_org_member_agent_settings_diff(row: Mapping[str, Any]) -> dict[str, A
|
||||
'model': row['llm_model'],
|
||||
'base_url': row['llm_base_url'],
|
||||
},
|
||||
'mcp_config': _normalize_mcp_config(row['mcp_config']),
|
||||
'mcp_config': row['mcp_config'],
|
||||
}
|
||||
)
|
||||
merged = _deep_merge(
|
||||
generated,
|
||||
_normalize_nested_mcp_config(row.get('agent_settings_diff')),
|
||||
)
|
||||
return _normalize_nested_mcp_config(merged)
|
||||
return _deep_merge(generated, row.get('agent_settings_diff') or {})
|
||||
|
||||
|
||||
def _build_org_member_conversation_settings_diff(
|
||||
@@ -248,14 +113,10 @@ def _build_org_agent_settings(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
'enabled': row['enable_default_condenser'],
|
||||
'max_size': row['condenser_max_size'],
|
||||
},
|
||||
'mcp_config': _normalize_mcp_config(row['mcp_config']),
|
||||
'mcp_config': row['mcp_config'],
|
||||
}
|
||||
)
|
||||
merged = _deep_merge(
|
||||
generated,
|
||||
_normalize_nested_mcp_config(row.get('agent_settings')),
|
||||
)
|
||||
return _normalize_nested_mcp_config(merged)
|
||||
return _deep_merge(generated, row.get('agent_settings') or {})
|
||||
|
||||
|
||||
def _build_org_conversation_settings(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
@@ -311,9 +172,7 @@ def _legacy_org_member_values(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
'max_iterations': _get_nested_value(
|
||||
conversation_settings_diff, 'max_iterations'
|
||||
),
|
||||
'mcp_config': _to_legacy_mcp_config(
|
||||
_get_nested_value(agent_settings_diff, 'mcp_config')
|
||||
),
|
||||
'mcp_config': _get_nested_value(agent_settings_diff, 'mcp_config'),
|
||||
}
|
||||
|
||||
|
||||
@@ -337,9 +196,7 @@ def _legacy_org_values(row: Mapping[str, Any]) -> dict[str, Any]:
|
||||
'enable_default_condenser': (
|
||||
True if condenser_enabled is None else condenser_enabled
|
||||
),
|
||||
'mcp_config': _to_legacy_mcp_config(
|
||||
_get_nested_value(agent_settings, 'mcp_config')
|
||||
),
|
||||
'mcp_config': _get_nested_value(agent_settings, 'mcp_config'),
|
||||
'condenser_max_size': _get_nested_value(
|
||||
agent_settings, 'condenser', 'max_size'
|
||||
),
|
||||
|
||||
@@ -106,8 +106,15 @@ if GITHUB_APP_CLIENT_ID:
|
||||
|
||||
# Add GitLab integration router only if GITLAB_APP_CLIENT_ID is set
|
||||
if GITLAB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.gitlab.gitlab_v1_callback_processor import ( # noqa: E402
|
||||
GitlabV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.gitlab import gitlab_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GitlabV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
|
||||
@@ -703,41 +703,6 @@ async def accept_tos(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
@api_router.get('/onboarding_status')
|
||||
async def onboarding_status(request: Request):
|
||||
"""Return whether the current user must still complete onboarding.
|
||||
|
||||
Kept as a dedicated endpoint instead of riding on ``GET /api/v1/settings``
|
||||
(the natural home for fields like ``email_verified``) because the settings
|
||||
response is heavyweight: ``SaasSettingsStore.load`` joins User, Org, and
|
||||
OrgMember rows and deep-merges the org-level and member-level
|
||||
``agent_settings`` before returning. Onboarding gating runs on every
|
||||
protected-route navigation, so we need a lightweight read of a single
|
||||
boolean rather than paying for the full settings aggregation.
|
||||
"""
|
||||
user_auth = cast(SaasUserAuth, await get_user_auth(request))
|
||||
user_id = await user_auth.get_user_id()
|
||||
|
||||
if not user_id:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User is not authenticated'},
|
||||
)
|
||||
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'User not found'},
|
||||
)
|
||||
|
||||
should_complete = await _should_redirect_to_onboarding(user_id, user)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'should_complete_onboarding': should_complete},
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/complete_onboarding')
|
||||
async def complete_onboarding(request: Request):
|
||||
"""Mark onboarding as completed for the current user."""
|
||||
|
||||
@@ -2,7 +2,15 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
@@ -15,12 +23,15 @@ from integrations.models import Message, SourceType
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL, IS_LOCAL_DEPLOYMENT
|
||||
from pydantic import BaseModel
|
||||
from server.auth.constants import AUTOMATION_EVENT_FORWARDING_ENABLED
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.server.shared import sio
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
@@ -29,6 +40,7 @@ webhook_store = GitlabWebhookStore()
|
||||
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
automation_event_service = AutomationEventService(token_manager)
|
||||
|
||||
|
||||
# Request/Response models
|
||||
@@ -82,6 +94,7 @@ async def verify_gitlab_signature(
|
||||
@gitlab_integration_router.post('/gitlab/events')
|
||||
async def gitlab_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_gitlab_token: str = Header(None),
|
||||
x_openhands_webhook_id: str = Header(None),
|
||||
x_openhands_user_id: str = Header(None),
|
||||
@@ -112,6 +125,16 @@ async def gitlab_events(
|
||||
content={'message': 'Duplicate GitLab event ignored.'},
|
||||
)
|
||||
|
||||
# Forward to automation service (fire-and-forget background task)
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
background_tasks.add_task(
|
||||
automation_event_service.forward_event,
|
||||
provider=ProviderType.GITLAB,
|
||||
payload=payload_data,
|
||||
installation_id=x_openhands_webhook_id,
|
||||
)
|
||||
|
||||
# Existing resolver bot processing
|
||||
message = Message(
|
||||
source=SourceType.GITLAB,
|
||||
message={
|
||||
|
||||
@@ -180,6 +180,18 @@ async def device_token(device_code: str = Form(...)):
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'authorized':
|
||||
# Verify user_id is set (should always be true for authorized status)
|
||||
if not device_code_entry.keycloak_user_id:
|
||||
logger.error(
|
||||
'Authorized device code missing user_id',
|
||||
extra={'user_code': device_code_entry.user_code},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'User identification missing',
|
||||
)
|
||||
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
|
||||
@@ -163,7 +163,8 @@ class AutomationEventService:
|
||||
org_id = await self._resolve_git_org(provider, git_org_name)
|
||||
|
||||
# Fallback for personal repos (owner_type indicates individual user)
|
||||
if not org_id and owner_type == 'User':
|
||||
# GitHub uses 'User', GitLab uses 'user'
|
||||
if not org_id and owner_type and owner_type.lower() == 'user':
|
||||
org_id = await self._resolve_personal_org(provider, owner_id)
|
||||
if org_id:
|
||||
logger.info(
|
||||
@@ -206,6 +207,18 @@ class AutomationEventService:
|
||||
owner = repo.get('owner', {})
|
||||
return owner.get('login'), owner.get('type'), owner.get('id')
|
||||
|
||||
if provider == ProviderType.GITLAB:
|
||||
# GitLab uses 'project' instead of 'repository'
|
||||
# path_with_namespace is like "org-name/repo-name" or "user-name/repo-name"
|
||||
project = payload.get('project', {})
|
||||
path_with_namespace = project.get('path_with_namespace', '')
|
||||
git_org = path_with_namespace.split('/')[0] if path_with_namespace else None
|
||||
namespace = project.get('namespace', {})
|
||||
# GitLab uses 'group' for organizations and 'user' for personal projects
|
||||
owner_type = namespace.get('kind')
|
||||
owner_id = namespace.get('id')
|
||||
return git_org, owner_type, owner_id
|
||||
|
||||
logger.warning(f'Unsupported provider ({provider.value})')
|
||||
return None, None, None
|
||||
|
||||
|
||||
@@ -350,8 +350,8 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
user_result = await self.db_session.execute(user_query)
|
||||
user = user_result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Determine org_id: prefer API key's org_id if authenticated via API key
|
||||
@@ -372,8 +372,8 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
saas_result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = saas_result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == org_id
|
||||
|
||||
@@ -138,7 +138,8 @@ class VerifiedModelService:
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
stored = result.scalars().first()
|
||||
return verified_model(stored) if stored else None
|
||||
|
||||
async def create_verified_model(
|
||||
self,
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from sqlalchemy import and_, asc, select, text, update
|
||||
from sqlalchemy import and_, asc, delete, select, text, update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from storage.database import a_session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
@@ -25,6 +25,8 @@ class GitlabWebhookStore:
|
||||
|
||||
if webhook.group_id:
|
||||
return (GitLabResourceType.GROUP, webhook.group_id)
|
||||
# At this point, project_id must be set (we checked at least one is set above)
|
||||
assert webhook.project_id is not None
|
||||
return (GitLabResourceType.PROJECT, webhook.project_id)
|
||||
|
||||
async def store_webhooks(self, project_details: list[GitlabWebhook]) -> None:
|
||||
@@ -123,11 +125,11 @@ class GitlabWebhookStore:
|
||||
async with session.begin():
|
||||
# Create query based on the identifier provided
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
query = GitlabWebhook.__table__.delete().where(
|
||||
query = delete(GitlabWebhook).where(
|
||||
GitlabWebhook.project_id == resource_id
|
||||
)
|
||||
else: # has_group_id must be True based on validation
|
||||
query = GitlabWebhook.__table__.delete().where(
|
||||
query = delete(GitlabWebhook).where(
|
||||
GitlabWebhook.group_id == resource_id
|
||||
)
|
||||
|
||||
|
||||
@@ -402,9 +402,7 @@ class LiteLlmManager:
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
# Update user_settings with the new key so it gets stored in org_member
|
||||
# agent_settings is a JSON column (dict) on UserSettings
|
||||
if user_settings.agent_settings is None:
|
||||
user_settings.agent_settings = {}
|
||||
# agent_settings is a non-nullable JSON column (dict) on UserSettings
|
||||
user_settings.agent_settings.setdefault('llm', {})[
|
||||
'api_key'
|
||||
] = new_key
|
||||
|
||||
@@ -35,10 +35,10 @@ class OrgAppSettingsStore:
|
||||
Org: The organization object, or None if not found
|
||||
"""
|
||||
# Get user with their current_org_id
|
||||
result = await self.db_session.execute(
|
||||
user_result = await self.db_session.execute(
|
||||
select(User).filter(User.id == UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
user = user_result.scalars().first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
@@ -48,8 +48,8 @@ class OrgAppSettingsStore:
|
||||
return None
|
||||
|
||||
# Get the organization
|
||||
result = await self.db_session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
org_result = await self.db_session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = org_result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
@@ -4,10 +4,13 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from typing import TYPE_CHECKING, Callable, ContextManager
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
@@ -31,14 +34,14 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
session_maker: Callable[[], ContextManager[Session]]
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
session_maker: sessionmaker,
|
||||
org_id: UUID | None,
|
||||
session_maker: Callable[[], ContextManager[Session]],
|
||||
resolver_org_id: UUID | None = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
@@ -65,7 +68,9 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
def _to_external_model(
|
||||
self, conversation_metadata: StoredConversationMetadata
|
||||
) -> ConversationMetadata:
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
@@ -216,7 +221,7 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _search():
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
stored_conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
@@ -233,13 +238,16 @@ class SaasConversationStore(ConversationStore):
|
||||
.limit(limit + 1)
|
||||
.all()
|
||||
)
|
||||
conversations = [self._to_external_model(c) for c in conversations]
|
||||
conversations = [
|
||||
self._to_external_model(c) for c in stored_conversations
|
||||
]
|
||||
current_page_size = len(conversations)
|
||||
next_page_id = offset_to_page_id(
|
||||
offset + limit, current_page_size > limit
|
||||
)
|
||||
conversations = conversations[:limit]
|
||||
return ConversationMetadataResultSet(conversations, next_page_id)
|
||||
return ConversationMetadataResultSet(
|
||||
conversations[:limit], next_page_id
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_search)
|
||||
|
||||
|
||||
@@ -60,13 +60,11 @@ class SaasSecretsStore(SecretsStore):
|
||||
async with a_session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete existing records for this user AND organization only
|
||||
# Note: user.current_org_id is non-nullable, so org_id is always set
|
||||
delete_query = delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id,
|
||||
StoredCustomSecrets.org_id == org_id,
|
||||
)
|
||||
if org_id is not None:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
else:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id.is_(None))
|
||||
await session.execute(delete_query)
|
||||
|
||||
# Prepare the new secrets data
|
||||
|
||||
@@ -4,10 +4,8 @@ Tests for:
|
||||
- _should_redirect_to_onboarding() function
|
||||
- _get_post_auth_redirect() function
|
||||
- /complete_onboarding endpoint
|
||||
- /onboarding_status endpoint
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -19,7 +17,6 @@ from server.routes.auth import (
|
||||
_get_post_auth_redirect,
|
||||
_should_redirect_to_onboarding,
|
||||
complete_onboarding,
|
||||
onboarding_status,
|
||||
)
|
||||
from storage.user import User
|
||||
|
||||
@@ -331,78 +328,3 @@ class TestCompleteOnboardingEndpoint:
|
||||
await complete_onboarding(mock_request)
|
||||
|
||||
mock_mark_completed.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
class TestOnboardingStatusEndpoint:
|
||||
"""Tests for the /onboarding_status API endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_401_when_not_authenticated(self, mock_request):
|
||||
"""Unauthenticated requests return 401."""
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_true_for_new_cloud_user(self, mock_request, mock_user):
|
||||
"""A cloud user whose onboarding is incomplete should be told to complete it."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user.onboarding_completed = False
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
body = json.loads(result.body)
|
||||
assert body == {'should_complete_onboarding': True}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_for_completed_user(self, mock_request, mock_user):
|
||||
"""A user who already completed onboarding should not be told to complete it."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user.onboarding_completed = True
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
result = await onboarding_status(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
body = json.loads(result.body)
|
||||
assert body == {'should_complete_onboarding': False}
|
||||
|
||||
@@ -826,3 +826,249 @@ class TestCacheHelpers:
|
||||
service = create_service(mock_token_manager)
|
||||
# Should not raise
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GitLab-specific tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitlab_group_payload():
|
||||
"""Create a sample GitLab webhook payload for a group-owned project."""
|
||||
return {
|
||||
'object_kind': 'issue',
|
||||
'event_type': 'issue',
|
||||
'project': {
|
||||
'id': 123456,
|
||||
'path_with_namespace': 'test-org/test-repo',
|
||||
'visibility': 'public',
|
||||
'namespace': {
|
||||
'id': 789,
|
||||
'kind': 'group',
|
||||
'name': 'test-org',
|
||||
},
|
||||
},
|
||||
'user': {
|
||||
'id': 12345,
|
||||
'username': 'testuser',
|
||||
},
|
||||
'object_attributes': {
|
||||
'id': 1,
|
||||
'iid': 1,
|
||||
'title': 'Test Issue',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitlab_user_payload():
|
||||
"""Create a sample GitLab webhook payload for a user-owned project."""
|
||||
return {
|
||||
'object_kind': 'issue',
|
||||
'event_type': 'issue',
|
||||
'project': {
|
||||
'id': 654321,
|
||||
'path_with_namespace': 'testuser/personal-repo',
|
||||
'visibility': 'private',
|
||||
'namespace': {
|
||||
'id': 12345,
|
||||
'kind': 'user',
|
||||
'name': 'testuser',
|
||||
},
|
||||
},
|
||||
'user': {
|
||||
'id': 12345,
|
||||
'username': 'testuser',
|
||||
},
|
||||
'object_attributes': {
|
||||
'id': 2,
|
||||
'iid': 1,
|
||||
'title': 'Personal Issue',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestExtractOwnerInfoGitLab:
|
||||
"""Tests for _extract_owner_info method with GitLab payloads."""
|
||||
|
||||
def test_extract_gitlab_group_owner(self, mock_token_manager, gitlab_group_payload):
|
||||
"""
|
||||
GIVEN: GitLab payload for a group-owned project
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns correct git_org, owner_type, owner_id
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
assert git_org == 'test-org'
|
||||
assert owner_type == 'group'
|
||||
assert owner_id == 789
|
||||
|
||||
def test_extract_gitlab_user_owner(self, mock_token_manager, gitlab_user_payload):
|
||||
"""
|
||||
GIVEN: GitLab payload for a user-owned project
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns correct git_org, owner_type, owner_id
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, gitlab_user_payload
|
||||
)
|
||||
|
||||
assert git_org == 'testuser'
|
||||
assert owner_type == 'user'
|
||||
assert owner_id == 12345
|
||||
|
||||
def test_extract_gitlab_missing_project(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: GitLab payload without project data
|
||||
WHEN: _extract_owner_info is called
|
||||
THEN: Returns None values
|
||||
"""
|
||||
payload = {'object_kind': 'issue'}
|
||||
|
||||
with patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
git_org, owner_type, owner_id = service._extract_owner_info(
|
||||
ProviderType.GITLAB, payload
|
||||
)
|
||||
|
||||
assert git_org is None
|
||||
assert owner_type is None
|
||||
assert owner_id is None
|
||||
|
||||
|
||||
class TestResolveOrgContextGitLab:
|
||||
"""Tests for _resolve_org_context method with GitLab payloads."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_group_org(
|
||||
self, mock_token_manager, mock_org_git_claim, gitlab_group_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a group project with claimed org
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Returns correct OrgContext with org_id
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.org_id == mock_org_git_claim.org_id
|
||||
assert result.git_org == 'test-org'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_user_personal_org_fallback(
|
||||
self, mock_token_manager, mock_org_git_claim, gitlab_user_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a user project (not claimed via OrgGitClaim)
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Falls back to personal org resolution
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No OrgGitClaim for user
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
# Mock _resolve_personal_org to return a personal org
|
||||
service._resolve_personal_org = AsyncMock(
|
||||
return_value=mock_org_git_claim.org_id
|
||||
)
|
||||
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_user_payload
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.org_id == mock_org_git_claim.org_id
|
||||
assert result.git_org == 'testuser'
|
||||
# Verify personal org fallback was called with GitLab user ID
|
||||
service._resolve_personal_org.assert_called_once_with(
|
||||
ProviderType.GITLAB, 12345
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_no_org_found(
|
||||
self, mock_token_manager, gitlab_group_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab payload for a project with no org claim and no personal org
|
||||
WHEN: _resolve_org_context is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_org_context(
|
||||
ProviderType.GITLAB, gitlab_group_payload
|
||||
)
|
||||
|
||||
# Group owner doesn't fall back to personal org
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestResolveGitOrgGitLab:
|
||||
"""Tests for _resolve_git_org method with GitLab provider."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_gitlab_org_includes_provider_in_cache_key(
|
||||
self, mock_token_manager, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: GitLab provider with an org name
|
||||
WHEN: _resolve_git_org is called
|
||||
THEN: Cache key includes the gitlab provider name
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
await service._resolve_git_org(ProviderType.GITLAB, 'Test-Org')
|
||||
|
||||
# Verify cache key includes provider and normalized org name
|
||||
get_call_args = mock_redis.get.call_args[0][0]
|
||||
assert 'gitlab' in get_call_args
|
||||
assert 'test-org' in get_call_args # Lowercase normalized
|
||||
|
||||
@@ -8,16 +8,23 @@ import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
from server.routes.integration.gitlab import gitlab_events
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_background_tasks():
|
||||
"""Create a mock BackgroundTasks."""
|
||||
return MagicMock(spec=BackgroundTasks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.gitlab.verify_gitlab_signature')
|
||||
@patch('server.routes.integration.gitlab.gitlab_manager')
|
||||
@patch('server.routes.integration.gitlab.sio')
|
||||
async def test_gitlab_events_deduplication_with_object_id(
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature, mock_background_tasks
|
||||
):
|
||||
"""Test that duplicate GitLab events are deduplicated using object_attributes.id."""
|
||||
# Setup mocks
|
||||
@@ -47,6 +54,7 @@ async def test_gitlab_events_deduplication_with_object_id(
|
||||
# Call the endpoint
|
||||
response = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -70,6 +78,7 @@ async def test_gitlab_events_deduplication_with_object_id(
|
||||
# Call the endpoint again with the same payload
|
||||
response = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -92,7 +101,7 @@ async def test_gitlab_events_deduplication_with_object_id(
|
||||
@patch('server.routes.integration.gitlab.gitlab_manager')
|
||||
@patch('server.routes.integration.gitlab.sio')
|
||||
async def test_gitlab_events_deduplication_without_object_id(
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature, mock_background_tasks
|
||||
):
|
||||
"""Test that GitLab events without object_attributes.id are deduplicated using hash of payload."""
|
||||
# Setup mocks
|
||||
@@ -127,6 +136,7 @@ async def test_gitlab_events_deduplication_without_object_id(
|
||||
# Call the endpoint
|
||||
response = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -150,6 +160,7 @@ async def test_gitlab_events_deduplication_without_object_id(
|
||||
# Call the endpoint again with the same payload
|
||||
response = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -172,7 +183,7 @@ async def test_gitlab_events_deduplication_without_object_id(
|
||||
@patch('server.routes.integration.gitlab.gitlab_manager')
|
||||
@patch('server.routes.integration.gitlab.sio')
|
||||
async def test_gitlab_events_different_payloads_not_deduplicated(
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature, mock_background_tasks
|
||||
):
|
||||
"""Test that different GitLab events are not deduplicated."""
|
||||
# Setup mocks
|
||||
@@ -196,6 +207,7 @@ async def test_gitlab_events_different_payloads_not_deduplicated(
|
||||
# Call the endpoint with first payload
|
||||
response1 = await gitlab_events(
|
||||
request=mock_request1,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -223,6 +235,7 @@ async def test_gitlab_events_different_payloads_not_deduplicated(
|
||||
# Call the endpoint with second payload
|
||||
response2 = await gitlab_events(
|
||||
request=mock_request2,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -242,7 +255,7 @@ async def test_gitlab_events_different_payloads_not_deduplicated(
|
||||
@patch('server.routes.integration.gitlab.gitlab_manager')
|
||||
@patch('server.routes.integration.gitlab.sio')
|
||||
async def test_gitlab_events_multiple_identical_payloads_deduplicated(
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature
|
||||
mock_sio, mock_gitlab_manager, mock_verify_signature, mock_background_tasks
|
||||
):
|
||||
"""Test that multiple identical GitLab events are properly deduplicated."""
|
||||
# Setup mocks
|
||||
@@ -273,6 +286,7 @@ async def test_gitlab_events_multiple_identical_payloads_deduplicated(
|
||||
# Call the endpoint first time
|
||||
response1 = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -298,6 +312,7 @@ async def test_gitlab_events_multiple_identical_payloads_deduplicated(
|
||||
# Call the endpoint second time with the same payload
|
||||
response2 = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
@@ -321,6 +336,7 @@ async def test_gitlab_events_multiple_identical_payloads_deduplicated(
|
||||
# Call the endpoint third time with the same payload
|
||||
response3 = await gitlab_events(
|
||||
request=mock_request,
|
||||
background_tasks=mock_background_tasks,
|
||||
x_gitlab_token='test_token',
|
||||
x_openhands_webhook_id='test_webhook_id',
|
||||
x_openhands_user_id='test_user_id',
|
||||
|
||||
@@ -50,41 +50,6 @@ def test_user_settings_are_split_into_agent_and_conversation_buckets():
|
||||
}
|
||||
|
||||
|
||||
def test_user_settings_normalize_legacy_mcp_config():
|
||||
row = {
|
||||
'agent': 'CodeActAgent',
|
||||
'max_iterations': 42,
|
||||
'security_analyzer': 'llm',
|
||||
'confirmation_mode': True,
|
||||
'llm_model': 'anthropic/claude-sonnet-4-5-20250929',
|
||||
'llm_base_url': 'https://api.example.com',
|
||||
'enable_default_condenser': False,
|
||||
'condenser_max_size': 128,
|
||||
'mcp_config': {
|
||||
'sse_servers': [],
|
||||
'stdio_servers': [],
|
||||
'shttp_servers': [
|
||||
{'url': 'https://mcp.example.com', 'api_key': None, 'timeout': 60}
|
||||
],
|
||||
},
|
||||
'agent_settings': {},
|
||||
'conversation_settings': {},
|
||||
}
|
||||
|
||||
assert migration_108._build_user_agent_settings(row) == {
|
||||
'schema_version': 1,
|
||||
'agent': 'CodeActAgent',
|
||||
'llm': {
|
||||
'model': 'anthropic/claude-sonnet-4-5-20250929',
|
||||
'base_url': 'https://api.example.com',
|
||||
},
|
||||
'condenser': {'enabled': False, 'max_size': 128},
|
||||
'mcp_config': {
|
||||
'mcpServers': {'shttp': {'url': 'https://mcp.example.com', 'timeout': 60}}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_org_member_diffs_use_nested_llm_and_conversation_settings():
|
||||
row = {
|
||||
'max_iterations': 50,
|
||||
@@ -111,36 +76,6 @@ def test_org_member_diffs_use_nested_llm_and_conversation_settings():
|
||||
assert conversation_settings_diff == {'max_iterations': 50}
|
||||
|
||||
|
||||
def test_org_member_diffs_normalize_legacy_mcp_config():
|
||||
row = {
|
||||
'max_iterations': 50,
|
||||
'llm_model': 'openhands/claude-3',
|
||||
'llm_base_url': 'https://proxy.example.com',
|
||||
'mcp_config': {
|
||||
'sse_servers': [],
|
||||
'stdio_servers': [],
|
||||
'shttp_servers': [
|
||||
{'url': 'https://mcp.deepwiki.com/mcp', 'api_key': None, 'timeout': 60}
|
||||
],
|
||||
},
|
||||
'agent_settings_diff': {},
|
||||
'conversation_settings_diff': {},
|
||||
}
|
||||
|
||||
assert migration_108._build_org_member_agent_settings_diff(row) == {
|
||||
'schema_version': 1,
|
||||
'llm': {
|
||||
'model': 'openhands/claude-3',
|
||||
'base_url': 'https://proxy.example.com',
|
||||
},
|
||||
'mcp_config': {
|
||||
'mcpServers': {
|
||||
'shttp': {'url': 'https://mcp.deepwiki.com/mcp', 'timeout': 60}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_org_settings_are_split_into_agent_and_conversation_buckets():
|
||||
row = {
|
||||
'agent': 'CodeActAgent',
|
||||
@@ -206,42 +141,6 @@ def test_downgrade_extracts_legacy_values_from_nested_settings():
|
||||
}
|
||||
|
||||
|
||||
def test_downgrade_restores_legacy_mcp_config_from_sdk_settings():
|
||||
row = {
|
||||
'agent_settings_diff': {
|
||||
'schema_version': 1,
|
||||
'mcp_config': {
|
||||
'mcpServers': {
|
||||
'sse': {'url': 'https://mcp.example.com', 'transport': 'sse'},
|
||||
'shttp': {
|
||||
'url': 'https://mcp.deepwiki.com/mcp',
|
||||
'timeout': 60,
|
||||
},
|
||||
'deepwiki-stdio': {
|
||||
'command': 'npx',
|
||||
'args': ['-y', 'deepwiki-mcp'],
|
||||
'env': {'A': 'B'},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
'conversation_settings_diff': {},
|
||||
}
|
||||
|
||||
assert migration_108._legacy_org_member_values(row)['mcp_config'] == {
|
||||
'sse_servers': [{'url': 'https://mcp.example.com'}],
|
||||
'stdio_servers': [
|
||||
{
|
||||
'name': 'deepwiki-stdio',
|
||||
'command': 'npx',
|
||||
'args': ['-y', 'deepwiki-mcp'],
|
||||
'env': {'A': 'B'},
|
||||
}
|
||||
],
|
||||
'shttp_servers': [{'url': 'https://mcp.deepwiki.com/mcp', 'timeout': 60}],
|
||||
}
|
||||
|
||||
|
||||
def test_migrated_payload_loads_via_user_settings_to_settings():
|
||||
row = {
|
||||
'agent': 'CodeActAgent',
|
||||
|
||||
@@ -6,8 +6,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import i18n from "i18next";
|
||||
import OnboardingForm, { clientLoader } from "#/routes/onboarding-form";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import { onboardingService } from "#/api/onboarding-service/onboarding-service.api";
|
||||
|
||||
const mockMutate = vi.fn();
|
||||
const mockNavigate = vi.fn();
|
||||
@@ -58,12 +56,6 @@ vi.mock("#/api/option-service/option-service.api", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock feature flag - enable onboarding by default for tests
|
||||
const mockEnableOnboarding = vi.fn(() => true);
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_ONBOARDING: () => mockEnableOnboarding(),
|
||||
}));
|
||||
|
||||
const renderOnboardingForm = async () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
@@ -563,64 +555,14 @@ describe("OnboardingForm - Self-Hosted Mode", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("OnboardingForm - redirect when already onboarded", () => {
|
||||
beforeEach(() => {
|
||||
mockMutate.mockClear();
|
||||
mockNavigate.mockClear();
|
||||
mockUseMe.mockReturnValue({ data: { role: "member" } });
|
||||
loaderData = {
|
||||
config: {
|
||||
app_mode: "saas",
|
||||
feature_flags: { deployment_mode: "cloud" },
|
||||
},
|
||||
};
|
||||
mockGetConfig.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
feature_flags: { deployment_mode: "cloud" },
|
||||
});
|
||||
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
|
||||
});
|
||||
|
||||
it("should navigate to / when the backend reports onboarding is already complete", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(onboardingService, "getStatus").mockResolvedValue({
|
||||
should_complete_onboarding: false,
|
||||
});
|
||||
|
||||
// Act
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Assert
|
||||
await vi.waitFor(() => {
|
||||
expect(mockNavigate).toHaveBeenCalledWith("/", { replace: true });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("onboarding-form clientLoader", () => {
|
||||
beforeEach(() => {
|
||||
mockQueryClientGetData.mockReset();
|
||||
mockQueryClientSetData.mockReset();
|
||||
mockGetConfig.mockReset();
|
||||
mockEnableOnboarding.mockReturnValue(true);
|
||||
});
|
||||
|
||||
describe("redirect behavior", () => {
|
||||
it("should redirect to / when ENABLE_ONBOARDING feature flag is false", async () => {
|
||||
mockEnableOnboarding.mockReturnValue(false);
|
||||
const saasConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: { deployment_mode: "cloud" },
|
||||
};
|
||||
mockQueryClientGetData.mockReturnValue(saasConfig);
|
||||
|
||||
const result = await clientLoader();
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect((result as Response).status).toBe(302);
|
||||
expect((result as Response).headers.get("Location")).toBe("/");
|
||||
});
|
||||
|
||||
it("should redirect to / when app_mode is oss", async () => {
|
||||
const ossConfig = {
|
||||
app_mode: "oss",
|
||||
|
||||
@@ -27,6 +27,12 @@ describe("useIsOnIntermediatePage", () => {
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when on /onboarding page", () => {
|
||||
useLocationMock.mockReturnValue({ pathname: "/onboarding" });
|
||||
const { result } = renderHook(() => useIsOnIntermediatePage());
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when on /information-request page", () => {
|
||||
useLocationMock.mockReturnValue({ pathname: "/information-request" });
|
||||
const { result } = renderHook(() => useIsOnIntermediatePage());
|
||||
@@ -46,12 +52,6 @@ describe("useIsOnIntermediatePage", () => {
|
||||
const { result } = renderHook(() => useIsOnIntermediatePage());
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when on /onboarding page so settings/auth queries can fire", () => {
|
||||
useLocationMock.mockReturnValue({ pathname: "/onboarding" });
|
||||
const { result } = renderHook(() => useIsOnIntermediatePage());
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("handles edge cases", () => {
|
||||
|
||||
@@ -6,7 +6,6 @@ import MainApp from "#/routes/root-layout";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { onboardingService } from "#/api/onboarding-service/onboarding-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
@@ -52,12 +51,6 @@ vi.mock("#/hooks/use-invitation", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock feature flags - enable onboarding for tests that need it
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_ONBOARDING: () => true,
|
||||
ENABLE_AUTOMATIONS: () => false,
|
||||
}));
|
||||
|
||||
function LoginStub() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const emailVerificationRequired =
|
||||
@@ -118,23 +111,6 @@ const RouterStubWithLogin = createRoutesStub([
|
||||
},
|
||||
]);
|
||||
|
||||
const RouterStubWithOnboarding = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content" />,
|
||||
path: "/",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="onboarding-page" />,
|
||||
path: "/onboarding",
|
||||
},
|
||||
]);
|
||||
|
||||
const RouterStubWithDeviceVerify = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
@@ -217,10 +193,6 @@ describe("MainApp", () => {
|
||||
MOCK_DEFAULT_USER_SETTINGS,
|
||||
);
|
||||
|
||||
vi.spyOn(onboardingService, "getStatus").mockResolvedValue({
|
||||
should_complete_onboarding: false,
|
||||
});
|
||||
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null),
|
||||
setItem: vi.fn(),
|
||||
@@ -581,25 +553,4 @@ describe("MainApp", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Onboarding redirect", () => {
|
||||
it("should redirect authenticated SaaS users with incomplete onboarding to /onboarding", async () => {
|
||||
// Arrange: backend reports onboarding still required.
|
||||
vi.spyOn(onboardingService, "getStatus").mockResolvedValue({
|
||||
should_complete_onboarding: true,
|
||||
});
|
||||
|
||||
// Act: render the home page.
|
||||
renderWithLoginStub(RouterStubWithOnboarding, ["/"]);
|
||||
|
||||
// Assert: user lands on /onboarding instead of the home outlet.
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("onboarding-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
expect(screen.queryByTestId("outlet-content")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
|
||||
export type OnboardingStatusResponse = {
|
||||
should_complete_onboarding: boolean;
|
||||
};
|
||||
|
||||
export const onboardingService = {
|
||||
getStatus: async (): Promise<OnboardingStatusResponse> => {
|
||||
const { data } = await openHands.get<OnboardingStatusResponse>(
|
||||
"/api/onboarding_status",
|
||||
);
|
||||
return data;
|
||||
},
|
||||
};
|
||||
@@ -131,13 +131,6 @@ export interface IOption<T> {
|
||||
value: T;
|
||||
}
|
||||
|
||||
export interface MicroagentContentResponse {
|
||||
content: string;
|
||||
path: string;
|
||||
git_provider: Provider;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
export type GetFilesResponse = string[];
|
||||
|
||||
export interface GetFileResponse {
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import React from "react";
|
||||
import { useLocation, useNavigate } from "react-router";
|
||||
import { useOnboardingStatus } from "#/hooks/query/use-onboarding-status";
|
||||
import { ENABLE_ONBOARDING } from "#/utils/feature-flags";
|
||||
|
||||
/**
|
||||
* Forces SaaS users with incomplete onboarding to /onboarding before they can
|
||||
* access any protected route. Mirrors EmailVerificationGuard.
|
||||
*/
|
||||
export function OnboardingGuard({ children }: { children: React.ReactNode }) {
|
||||
const { data, isLoading } = useOnboardingStatus();
|
||||
const navigate = useNavigate();
|
||||
const { pathname } = useLocation();
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isLoading) return;
|
||||
// Only redirect to onboarding if the feature flag is enabled
|
||||
if (
|
||||
ENABLE_ONBOARDING() &&
|
||||
data?.should_complete_onboarding &&
|
||||
pathname !== "/onboarding"
|
||||
) {
|
||||
navigate("/onboarding", { replace: true });
|
||||
}
|
||||
}, [data?.should_complete_onboarding, isLoading, pathname, navigate]);
|
||||
|
||||
return children;
|
||||
}
|
||||
@@ -19,7 +19,6 @@ export const useSubmitOnboarding = () => {
|
||||
},
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
queryClient.invalidateQueries({ queryKey: ["onboarding-status"] });
|
||||
|
||||
const finalRedirectUrl = "/";
|
||||
// Check if the redirect URL is an external URL (starts with http or https)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { onboardingService } from "#/api/onboarding-service/onboarding-service.api";
|
||||
import { useConfig } from "./use-config";
|
||||
import { useIsAuthed } from "./use-is-authed";
|
||||
|
||||
export const useOnboardingStatus = () => {
|
||||
const { data: config } = useConfig();
|
||||
const { data: isAuthed } = useIsAuthed();
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["onboarding-status"],
|
||||
queryFn: onboardingService.getStatus,
|
||||
enabled: config?.app_mode === "saas" && !!isAuthed,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
gcTime: 1000 * 60 * 15,
|
||||
retry: false,
|
||||
meta: {
|
||||
disableToast: true,
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,6 +1,10 @@
|
||||
import { useLocation } from "react-router";
|
||||
|
||||
const INTERMEDIATE_PAGE_PATHS = ["/accept-tos", "/information-request"];
|
||||
const INTERMEDIATE_PAGE_PATHS = [
|
||||
"/accept-tos",
|
||||
"/onboarding",
|
||||
"/information-request",
|
||||
];
|
||||
|
||||
/**
|
||||
* Checks if the current page is an intermediate page.
|
||||
|
||||
@@ -7,7 +7,6 @@ import { BrandButton } from "#/components/features/settings/brand-button";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import OpenHandsLogoWhite from "#/assets/branding/openhands-logo-white.svg?react";
|
||||
import { useSubmitOnboarding } from "#/hooks/mutation/use-submit-onboarding";
|
||||
import { useOnboardingStatus } from "#/hooks/query/use-onboarding-status";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { useMe } from "#/hooks/query/use-me";
|
||||
@@ -22,14 +21,8 @@ import {
|
||||
} from "#/api/option-service/option.types";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { ENABLE_ONBOARDING } from "#/utils/feature-flags";
|
||||
|
||||
export const clientLoader = async () => {
|
||||
// Check feature flag FIRST (sync) to block access immediately without flash
|
||||
if (!ENABLE_ONBOARDING()) {
|
||||
return redirect("/");
|
||||
}
|
||||
|
||||
let config = queryClient.getQueryData<WebClientConfig>(["web-client-config"]);
|
||||
if (!config) {
|
||||
config = await OptionService.getConfig();
|
||||
@@ -89,22 +82,9 @@ function OnboardingForm() {
|
||||
const loaderData = useLoaderData<typeof clientLoader>();
|
||||
const config = loaderData?.config;
|
||||
const { data: me } = useMe();
|
||||
const { data: onboardingStatus, isLoading: isOnboardingStatusLoading } =
|
||||
useOnboardingStatus();
|
||||
const { mutate: submitOnboarding } = useSubmitOnboarding();
|
||||
const { trackOnboardingCompleted } = useTracking();
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isOnboardingStatusLoading) return;
|
||||
if (onboardingStatus?.should_complete_onboarding === false) {
|
||||
navigate("/", { replace: true });
|
||||
}
|
||||
}, [
|
||||
onboardingStatus?.should_complete_onboarding,
|
||||
isOnboardingStatusLoading,
|
||||
navigate,
|
||||
]);
|
||||
|
||||
const onboardingAppMode: OnboardingAppMode = getOnboardingAppMode(
|
||||
config?.feature_flags?.deployment_mode,
|
||||
);
|
||||
|
||||
@@ -26,7 +26,6 @@ import { useSyncPostHogConsent } from "#/hooks/use-sync-posthog-consent";
|
||||
import { useAutoSelectOrganization } from "#/hooks/use-auto-select-organization";
|
||||
import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage";
|
||||
import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard";
|
||||
import { OnboardingGuard } from "#/components/features/guards/onboarding-guard";
|
||||
import { AlertBanner } from "#/components/features/alerts/alert-banner";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
@@ -279,11 +278,9 @@ export default function MainApp() {
|
||||
id="root-outlet"
|
||||
className="flex-1 relative overflow-auto custom-scrollbar"
|
||||
>
|
||||
<OnboardingGuard>
|
||||
<EmailVerificationGuard>
|
||||
<Outlet />
|
||||
</EmailVerificationGuard>
|
||||
</OnboardingGuard>
|
||||
<EmailVerificationGuard>
|
||||
<Outlet />
|
||||
</EmailVerificationGuard>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -20,4 +20,3 @@ export const ENABLE_TRAJECTORY_REPLAY = () =>
|
||||
export const ENABLE_SANDBOX_GROUPING = () =>
|
||||
loadFeatureFlag("SANDBOX_GROUPING");
|
||||
export const ENABLE_AUTOMATIONS = () => loadFeatureFlag("AUTOMATIONS");
|
||||
export const ENABLE_ONBOARDING = () => loadFeatureFlag("ONBOARDING");
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from openhands.agent_server.models import OpenHandsModel, SendMessageRequest
|
||||
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||
@@ -175,6 +175,18 @@ class AppConversationStartRequest(OpenHandsModel):
|
||||
),
|
||||
)
|
||||
|
||||
# Secrets passed directly via API at conversation start time
|
||||
secrets: dict[str, SecretStr] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
'Secrets to pass to the conversation. These are merged with any '
|
||||
'existing secrets (from database or git providers), with API-provided '
|
||||
'secrets taking precedence (overriding any existing secret with the same name). '
|
||||
'Keys are secret names (e.g., "MY_API_KEY"), values are the secret values. '
|
||||
'Warning: Providing a secret that already exists will silently override it.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AppConversationUpdateRequest(BaseModel):
|
||||
"""Request model for updating conversation metadata.
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, AsyncGenerator, Sequence, cast
|
||||
@@ -309,6 +310,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace=remote_workspace,
|
||||
selected_repository=request.selected_repository,
|
||||
plugins=request.plugins,
|
||||
api_secrets=request.secrets,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1216,6 +1218,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: AsyncRemoteWorkspace | None = None,
|
||||
selected_repository: str | None = None,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
api_secrets: dict[str, SecretStr] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Build a complete StartConversationRequest for a user.
|
||||
|
||||
@@ -1224,6 +1227,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
Server-only overrides (system prompts, LLM tracing metadata,
|
||||
skills, hooks) are applied to the agent after creation.
|
||||
Finally delegates to ``ConversationSettings.create_request()``.
|
||||
|
||||
Args:
|
||||
sandbox: Sandbox information
|
||||
conversation_id: Unique conversation identifier
|
||||
initial_message: Optional initial message to send
|
||||
system_message_suffix: Optional suffix for system message
|
||||
git_provider: Optional git provider type
|
||||
working_dir: Working directory path
|
||||
agent_type: Type of agent (DEFAULT or PLAN)
|
||||
llm_model: Optional specific LLM model to use
|
||||
remote_workspace: Optional remote workspace instance
|
||||
selected_repository: Optional repository name
|
||||
plugins: Optional list of plugins to load
|
||||
api_secrets: Optional secrets passed directly via the API.
|
||||
These are merged with existing secrets (from database
|
||||
and git providers), with API-provided secrets taking
|
||||
precedence.
|
||||
"""
|
||||
user = await self.user_context.get_user_info()
|
||||
|
||||
@@ -1231,8 +1251,28 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
workspace = LocalWorkspace(working_dir=project_dir)
|
||||
|
||||
# --- secrets --------------------------------------------------------
|
||||
# Start with secrets from git providers and database
|
||||
secrets = await self._setup_secrets_for_git_providers(user)
|
||||
|
||||
# Merge API-provided secrets (they take precedence over existing ones)
|
||||
if api_secrets:
|
||||
from openhands.app_server.constants import (
|
||||
validate_secret_name,
|
||||
validate_secrets_dict,
|
||||
)
|
||||
|
||||
# Validate overall dict size limits first
|
||||
# Cast to Mapping for mypy compatibility (Mapping is covariant in value type)
|
||||
validate_secrets_dict(cast('Mapping[str, object]', api_secrets))
|
||||
|
||||
for name, value in api_secrets.items():
|
||||
validate_secret_name(name)
|
||||
if name in secrets:
|
||||
_logger.warning(
|
||||
'API-provided secret %r overrides existing secret', name
|
||||
)
|
||||
secrets[name] = StaticSecret(value=value)
|
||||
|
||||
# --- LLM + MCP -----------------------------------------------------
|
||||
llm, mcp_config = await self._configure_llm_and_mcp(
|
||||
user, llm_model, conversation_id
|
||||
|
||||
@@ -21,7 +21,7 @@ import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
@@ -33,6 +33,7 @@ from sqlalchemy import (
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -523,19 +524,19 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# Rebuild token usage
|
||||
# Rebuild token usage (use 0 as default for nullable int columns)
|
||||
token_usage = TokenUsage(
|
||||
prompt_tokens=stored.prompt_tokens,
|
||||
completion_tokens=stored.completion_tokens,
|
||||
cache_read_tokens=stored.cache_read_tokens,
|
||||
cache_write_tokens=stored.cache_write_tokens,
|
||||
context_window=stored.context_window,
|
||||
per_turn_token=stored.per_turn_token,
|
||||
prompt_tokens=stored.prompt_tokens or 0,
|
||||
completion_tokens=stored.completion_tokens or 0,
|
||||
cache_read_tokens=stored.cache_read_tokens or 0,
|
||||
cache_write_tokens=stored.cache_write_tokens or 0,
|
||||
context_window=stored.context_window or 0,
|
||||
per_turn_token=stored.per_turn_token or 0,
|
||||
)
|
||||
|
||||
# Rebuild metrics object
|
||||
# Rebuild metrics object (use 0.0 as default for nullable float columns)
|
||||
metrics = MetricsSnapshot(
|
||||
accumulated_cost=stored.accumulated_cost,
|
||||
accumulated_cost=stored.accumulated_cost or 0.0,
|
||||
max_budget_per_task=stored.max_budget_per_task,
|
||||
accumulated_token_usage=token_usage,
|
||||
)
|
||||
@@ -547,7 +548,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
return AppConversationInfo(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
|
||||
sandbox_id=stored.sandbox_id,
|
||||
sandbox_id=sandbox_id, # Use the asserted non-None value
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
git_provider=(
|
||||
@@ -555,7 +556,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
),
|
||||
title=stored.title,
|
||||
trigger=ConversationTrigger(stored.trigger) if stored.trigger else None,
|
||||
pr_number=stored.pr_number,
|
||||
pr_number=stored.pr_number or [],
|
||||
llm_model=stored.llm_model,
|
||||
metrics=metrics,
|
||||
parent_conversation_id=(
|
||||
@@ -599,7 +600,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
)
|
||||
|
||||
# Execute the secure delete query
|
||||
result = await self.db_session.execute(delete_query)
|
||||
result = cast(CursorResult, await self.db_session.execute(delete_query))
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
@@ -19,11 +19,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import Enum, String, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -264,7 +265,7 @@ class SQLAppConversationStartTaskService(AppConversationStartTaskService):
|
||||
StoredAppConversationStartTask.created_by_user_id == self.user_id
|
||||
)
|
||||
|
||||
result = await self.session.execute(delete_query)
|
||||
result = cast(CursorResult, await self.session.execute(delete_query))
|
||||
|
||||
# Return True if any rows were affected
|
||||
return result.rowcount > 0
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
"""Constants for the OpenHands App Server.
|
||||
|
||||
This module contains constants that are used across the app server,
|
||||
including security-related configurations for secret name validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
|
||||
# =============================================================================
|
||||
# SECRET LIMITS (configurable via environment variables)
|
||||
# =============================================================================
|
||||
|
||||
# Maximum number of secrets that can be passed via API in a single request.
|
||||
# Prevents abuse by limiting the size of the secrets dictionary.
|
||||
# Override with: OH_MAX_API_SECRETS_COUNT
|
||||
MAX_API_SECRETS_COUNT: int = int(os.getenv('OH_MAX_API_SECRETS_COUNT', '50'))
|
||||
|
||||
# Maximum length of a secret name in characters.
|
||||
# Environment variable names should be concise; this prevents excessively long names.
|
||||
# Override with: OH_MAX_API_SECRET_NAME_LENGTH
|
||||
MAX_API_SECRET_NAME_LENGTH: int = int(os.getenv('OH_MAX_API_SECRET_NAME_LENGTH', '256'))
|
||||
|
||||
# Maximum length of a secret value in bytes.
|
||||
# 64KB is generous for API keys/tokens while preventing massive payloads.
|
||||
# Override with: OH_MAX_API_SECRET_VALUE_LENGTH
|
||||
MAX_API_SECRET_VALUE_LENGTH: int = int(
|
||||
os.getenv('OH_MAX_API_SECRET_VALUE_LENGTH', '65536')
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SECRET NAME VALIDATION
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# BLOCKED: These names CANNOT be used as user-provided secrets.
|
||||
#
|
||||
# These environment variables are injected into the agent-server container
|
||||
# at startup. User-provided secrets with these names would override them
|
||||
# when exported in bash commands, potentially breaking the sandbox or
|
||||
# creating security vulnerabilities.
|
||||
# -----------------------------------------------------------------------------
|
||||
BLOCKED_SECRET_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
# Agent-server container configuration (from initial_env)
|
||||
'OPENVSCODE_SERVER_ROOT',
|
||||
'OH_ENABLE_VNC',
|
||||
'LOG_JSON',
|
||||
'OH_CONVERSATIONS_PATH',
|
||||
'OH_BASH_EVENTS_DIR',
|
||||
'PYTHONUNBUFFERED',
|
||||
'ENV_LOG_LEVEL',
|
||||
# Webhook and CORS - overriding could redirect callbacks to malicious endpoints
|
||||
'OH_WEBHOOKS_0_BASE_URL',
|
||||
'OH_ALLOW_CORS_ORIGINS_0',
|
||||
# Worker ports - could break web application functionality
|
||||
'WORKER_1',
|
||||
'WORKER_2',
|
||||
}
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# BLOCKED PREFIXES: Secret names starting with these prefixes are blocked.
|
||||
#
|
||||
# LLM_* variables are auto-forwarded to the agent-server container to enforce
|
||||
# LLM controls (timeouts, retries, model restrictions, etc.). Allowing users
|
||||
# to override these would let them escape app-server LLM controls.
|
||||
# -----------------------------------------------------------------------------
|
||||
BLOCKED_SECRET_PREFIXES: tuple[str, ...] = ('LLM_',)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OVERRIDABLE: These are system-provided but users MAY override them.
|
||||
# Documented here for clarity - these are explicitly ALLOWED, not blocked.
|
||||
#
|
||||
# Use case: User wants to use their own credentials instead of the
|
||||
# organization-level credentials provided by the system.
|
||||
# -----------------------------------------------------------------------------
|
||||
OVERRIDABLE_SYSTEM_SECRETS: frozenset[str] = frozenset(
|
||||
{
|
||||
# Git Provider Tokens - users may provide their own credentials
|
||||
# Note: Provider tokens are fetched via app-server API, not container env
|
||||
'GITHUB_TOKEN',
|
||||
'GITLAB_TOKEN',
|
||||
'BITBUCKET_TOKEN',
|
||||
'AZURE_DEVOPS_TOKEN',
|
||||
'FORGEJO_TOKEN',
|
||||
# AWS Credentials - used for Bedrock LLM access
|
||||
# Users may want to use their own AWS account for Bedrock models
|
||||
'AWS_ACCESS_KEY_ID',
|
||||
'AWS_SECRET_ACCESS_KEY',
|
||||
'AWS_REGION_NAME',
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_secret_name(name: str) -> None:
|
||||
"""Validate that a secret name is allowed.
|
||||
|
||||
Args:
|
||||
name: The secret name to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If the name is blocked (exact match or prefix match),
|
||||
or exceeds the maximum length
|
||||
"""
|
||||
# Check name length
|
||||
if len(name) > MAX_API_SECRET_NAME_LENGTH:
|
||||
raise ValueError(
|
||||
f'Secret name exceeds maximum length of {MAX_API_SECRET_NAME_LENGTH} characters '
|
||||
f'(got {len(name)}). Configure via OH_MAX_API_SECRET_NAME_LENGTH.'
|
||||
)
|
||||
|
||||
upper_name = name.upper()
|
||||
|
||||
# Check exact matches
|
||||
if upper_name in BLOCKED_SECRET_NAMES:
|
||||
raise ValueError(
|
||||
f"Secret name '{name}' is reserved for internal use and cannot be overridden. "
|
||||
f'See openhands.app_server.constants for the list of blocked names.'
|
||||
)
|
||||
|
||||
# Check prefix matches
|
||||
for prefix in BLOCKED_SECRET_PREFIXES:
|
||||
if upper_name.startswith(prefix):
|
||||
raise ValueError(
|
||||
f"Secret name '{name}' starts with reserved prefix '{prefix}' and cannot be used. "
|
||||
f'These variables are used for LLM configuration controls.'
|
||||
)
|
||||
|
||||
# Note: OVERRIDABLE_SYSTEM_SECRETS are intentionally allowed
|
||||
|
||||
|
||||
def validate_secrets_dict(secrets: Mapping[str, object] | None) -> None:
|
||||
"""Validate the entire secrets dictionary for size limits.
|
||||
|
||||
This should be called before iterating over individual secrets.
|
||||
|
||||
Args:
|
||||
secrets: The secrets dictionary to validate (can be None).
|
||||
Values can be str or SecretStr (uses get_secret_value()).
|
||||
|
||||
Raises:
|
||||
ValueError: If the dictionary exceeds size limits
|
||||
"""
|
||||
if secrets is None:
|
||||
return
|
||||
|
||||
# Check number of secrets
|
||||
if len(secrets) > MAX_API_SECRETS_COUNT:
|
||||
raise ValueError(
|
||||
f'Too many secrets provided: {len(secrets)} exceeds maximum of '
|
||||
f'{MAX_API_SECRETS_COUNT}. Configure via OH_MAX_API_SECRETS_COUNT.'
|
||||
)
|
||||
|
||||
# Check individual value lengths
|
||||
for name, value in secrets.items():
|
||||
# Handle both str and SecretStr (Pydantic's SecretStr has get_secret_value())
|
||||
if hasattr(value, 'get_secret_value'):
|
||||
value_str = value.get_secret_value() # type: ignore[union-attr]
|
||||
else:
|
||||
value_str = str(value)
|
||||
value_bytes = len(value_str.encode('utf-8'))
|
||||
if value_bytes > MAX_API_SECRET_VALUE_LENGTH:
|
||||
raise ValueError(
|
||||
f"Secret '{name}' value exceeds maximum length of "
|
||||
f'{MAX_API_SECRET_VALUE_LENGTH} bytes (got {value_bytes}). '
|
||||
f'Configure via OH_MAX_API_SECRET_VALUE_LENGTH.'
|
||||
)
|
||||
@@ -29,6 +29,10 @@ from openhands.app_server.config import (
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventCallback
|
||||
from openhands.app_server.event_callback.set_title_callback_processor import (
|
||||
SetTitleCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.services.jwt_service import JwtService
|
||||
@@ -203,6 +207,9 @@ async def on_conversation_update(
|
||||
if conversation_info.execution_status == ConversationExecutionStatus.DELETING:
|
||||
return Success()
|
||||
|
||||
# Detect if this is a new conversation (stub has title=None)
|
||||
is_new_conversation = existing.title is None
|
||||
|
||||
# Merge tags from incoming conversation info
|
||||
# SDK can set tags via Conversation(tags=...) which includes automation context
|
||||
merged_tags = merge_conversation_tags(existing.tags, conversation_info.tags)
|
||||
@@ -237,6 +244,24 @@ async def on_conversation_update(
|
||||
app_conversation_info
|
||||
)
|
||||
|
||||
# Register SetTitleCallbackProcessor for new conversations created via webhook.
|
||||
# This enables auto-titling for conversations created directly on the agent-server
|
||||
# (e.g., automation runs) that notify the app-server via webhook.
|
||||
if is_new_conversation:
|
||||
state = InjectorState()
|
||||
setattr(
|
||||
state,
|
||||
USER_CONTEXT_ATTR,
|
||||
SpecifyUserContext(sandbox_info.created_by_user_id),
|
||||
)
|
||||
async with get_event_callback_service(state) as event_callback_service:
|
||||
await event_callback_service.save_event_callback(
|
||||
EventCallback(
|
||||
conversation_id=conversation_info.id,
|
||||
processor=SetTitleCallbackProcessor(),
|
||||
)
|
||||
)
|
||||
|
||||
return Success()
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
# The version of the agent server to use for deployments.
|
||||
# Typically this will be the same as the values from the pyproject.toml
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:1.17.0-python'
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:1.18.1-python'
|
||||
|
||||
|
||||
class SandboxSpecService(ABC):
|
||||
|
||||
@@ -24,7 +24,7 @@ DB_SESSION_ATTR = 'db_session'
|
||||
DB_SESSION_KEEP_OPEN_ATTR = 'db_session_keep_open'
|
||||
|
||||
|
||||
class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
class DbSessionInjector(BaseModel, Injector[AsyncSession]):
|
||||
persistence_dir: Path
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
@@ -166,6 +166,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
if self.gcp_db_instance: # GCP environments
|
||||
async_engine = await self._create_async_gcp_engine()
|
||||
else:
|
||||
url: str | URL
|
||||
if self.host:
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
@@ -199,6 +200,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
poolclass=NullPool,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
assert async_engine is not None # Always assigned in either branch above
|
||||
self._async_engine = async_engine
|
||||
return async_engine
|
||||
|
||||
@@ -209,6 +211,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
if self.gcp_db_instance: # GCP environments
|
||||
engine = self._create_gcp_engine()
|
||||
else:
|
||||
url: str | URL
|
||||
if self.host:
|
||||
try:
|
||||
import pg8000 # noqa: F401
|
||||
@@ -234,6 +237,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
pool_recycle=self.pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
assert engine is not None # Always assigned in either branch above
|
||||
self._engine = engine
|
||||
return engine
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
@@ -5,15 +6,14 @@ import yaml
|
||||
from fastapi import APIRouter, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
import openhands
|
||||
from openhands.app_server.utils.dependencies import get_dependencies
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.memory.memory import GLOBAL_MICROAGENTS_DIR, USER_MICROAGENTS_DIR
|
||||
|
||||
router = APIRouter(prefix='/skills', tags=['Skills'], dependencies=get_dependencies())
|
||||
|
||||
# Re-use V0 path constants (single source of truth)
|
||||
GLOBAL_SKILLS_DIR = Path(GLOBAL_MICROAGENTS_DIR)
|
||||
USER_SKILLS_DIR = Path(USER_MICROAGENTS_DIR)
|
||||
GLOBAL_SKILLS_DIR = Path(os.path.dirname(openhands.__file__)) / 'skills'
|
||||
USER_SKILLS_DIR = Path.home() / '.openhands' / 'microagents'
|
||||
|
||||
|
||||
class SkillInfo(BaseModel):
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
|
||||
__all__ = [
|
||||
'AgentController',
|
||||
]
|
||||
@@ -1,85 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.action import Action
|
||||
|
||||
|
||||
class ActionParseError(Exception):
|
||||
"""Exception raised when the response from the LLM cannot be parsed into an action."""
|
||||
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.error
|
||||
|
||||
|
||||
class ResponseParser(ABC):
|
||||
"""This abstract base class is a general interface for an response parser dedicated to
|
||||
parsing the action from the response from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
# Need pay attention to the item order in self.action_parsers
|
||||
self.action_parsers: list[ActionParser] = []
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, response: Any) -> Action:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response: The response from the LLM, which can be a string or a dictionary.
|
||||
|
||||
Returns:
|
||||
- action (Action): The action parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response: Any) -> str:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response: The response from the LLM, which can be a string or a dictionary.
|
||||
|
||||
Returns:
|
||||
- action_str (str): The action str parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_action(self, action_str: str) -> Action:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- action_str (str): The response from the LLM.
|
||||
|
||||
Returns:
|
||||
- action (Action): The action parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ActionParser(ABC):
|
||||
"""This abstract base class is a general interface for an action parser dedicated to
|
||||
parsing the action from the action str from the LLM.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def check_condition(self, action_str: str) -> bool:
|
||||
"""Check if the action string can be parsed by this parser."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, action_str: str) -> Action:
|
||||
"""Parses the action from the action string from the LLM response."""
|
||||
pass
|
||||
@@ -1,191 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
# V1 replacement for this module lives in the Software Agent SDK.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.utils.prompt import PromptManager
|
||||
from litellm import ChatCompletionToolParam
|
||||
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentNotRegisteredError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
|
||||
|
||||
class Agent(ABC):
|
||||
DEPRECATED = False
|
||||
"""
|
||||
This abstract base class is an general interface for an agent dedicated to
|
||||
executing a specific instruction and allowing human interaction with the
|
||||
agent during execution.
|
||||
It tracks the execution status and maintains a history of interactions.
|
||||
"""
|
||||
|
||||
_registry: dict[str, type['Agent']] = {}
|
||||
sandbox_plugins: list[PluginRequirement] = []
|
||||
|
||||
config_model: type[AgentConfig] = AgentConfig
|
||||
"""Class field that specifies the config model to use for the agent. Subclasses may override with a derived config model if needed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
):
|
||||
self.llm = llm_registry.get_llm_from_agent_config('agent', config)
|
||||
self.llm_registry = llm_registry
|
||||
self.config = config
|
||||
self._complete = False
|
||||
self._prompt_manager: 'PromptManager' | None = None
|
||||
self.mcp_tools: dict[str, ChatCompletionToolParam] = {}
|
||||
self.tools: list = []
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> 'PromptManager':
|
||||
if self._prompt_manager is None:
|
||||
raise ValueError(f'Prompt manager not initialized for agent {self.name}')
|
||||
return self._prompt_manager
|
||||
|
||||
def get_system_message(self) -> 'SystemMessageAction | None':
|
||||
"""Returns a SystemMessageAction containing the system message and tools.
|
||||
This will be added to the event stream as the first message.
|
||||
|
||||
Returns:
|
||||
SystemMessageAction: The system message action with content and tools
|
||||
None: If there was an error generating the system message
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
try:
|
||||
if not self.prompt_manager:
|
||||
logger.warning(
|
||||
f'[{self.name}] Prompt manager not initialized before getting system message'
|
||||
)
|
||||
return None
|
||||
|
||||
system_message = self.prompt_manager.get_system_message(
|
||||
cli_mode=self.config.cli_mode
|
||||
)
|
||||
|
||||
# Get tools if available
|
||||
tools = getattr(self, 'tools', None)
|
||||
|
||||
system_message_action = SystemMessageAction(
|
||||
content=system_message, tools=tools, agent_class=self.name
|
||||
)
|
||||
# Set the source attribute
|
||||
system_message_action._source = EventSource.AGENT # type: ignore
|
||||
|
||||
return system_message_action
|
||||
except Exception as e:
|
||||
logger.warning(f'[{self.name}] Failed to generate system message: {e}')
|
||||
return None
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Indicates whether the current instruction execution is complete.
|
||||
|
||||
Returns:
|
||||
- complete (bool): True if execution is complete; False otherwise.
|
||||
"""
|
||||
return self._complete
|
||||
|
||||
@abstractmethod
|
||||
def step(self, state: 'State') -> 'Action':
|
||||
"""Starts the execution of the assigned instruction. This method should
|
||||
be implemented by subclasses to define the specific execution logic.
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the agent's execution status."""
|
||||
# Only reset the completion status, not the LLM metrics
|
||||
self._complete = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, agent_cls: type['Agent']) -> None:
|
||||
"""Registers an agent class in the registry.
|
||||
|
||||
Parameters:
|
||||
- name (str): The name to register the class under.
|
||||
- agent_cls (Type['Agent']): The class to register.
|
||||
|
||||
Raises:
|
||||
- AgentAlreadyRegisteredError: If name already registered
|
||||
"""
|
||||
if name in cls._registry:
|
||||
raise AgentAlreadyRegisteredError(name)
|
||||
cls._registry[name] = agent_cls
|
||||
|
||||
@classmethod
|
||||
def get_cls(cls, name: str) -> type['Agent']:
|
||||
"""Retrieves an agent class from the registry.
|
||||
|
||||
Parameters:
|
||||
- name (str): The name of the class to retrieve
|
||||
|
||||
Returns:
|
||||
- agent_cls (Type['Agent']): The class registered under the specified name.
|
||||
|
||||
Raises:
|
||||
- AgentNotRegisteredError: If name not registered
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
raise AgentNotRegisteredError(name)
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def list_agents(cls) -> list[str]:
|
||||
"""Retrieves the list of all agent names from the registry.
|
||||
|
||||
Raises:
|
||||
- AgentNotRegisteredError: If no agent is registered
|
||||
"""
|
||||
if not bool(cls._registry):
|
||||
raise AgentNotRegisteredError()
|
||||
return list(cls._registry.keys())
|
||||
|
||||
def set_mcp_tools(self, mcp_tools: list[dict]) -> None:
|
||||
"""Sets the list of MCP tools for the agent.
|
||||
|
||||
Args:
|
||||
- mcp_tools (list[dict]): The list of MCP tools.
|
||||
"""
|
||||
logger.info(
|
||||
f'Setting {len(mcp_tools)} MCP tools for agent {self.name}: {[tool["function"]["name"] for tool in mcp_tools]}'
|
||||
)
|
||||
for tool in mcp_tools:
|
||||
_tool = ChatCompletionToolParam(**tool)
|
||||
if _tool['function']['name'] in self.mcp_tools:
|
||||
logger.warning(
|
||||
f'Tool {_tool["function"]["name"]} already exists, skipping'
|
||||
)
|
||||
continue
|
||||
self.mcp_tools[_tool['function']['name']] = _tool
|
||||
self.tools.append(_tool)
|
||||
logger.info(
|
||||
f'Tools updated for agent {self.name}, total {len(self.tools)}: {[tool["function"]["name"] for tool in self.tools]}'
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,105 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
|
||||
|
||||
class ReplayManager:
|
||||
"""ReplayManager manages the lifecycle of a replay session of a given trajectory.
|
||||
|
||||
Replay manager keeps track of a list of events, replays actions, and ignore
|
||||
messages and observations.
|
||||
|
||||
Note that unexpected or even errorneous results could happen if
|
||||
1) any action is non-deterministic, OR
|
||||
2) if the initial state before the replay session is different from the
|
||||
initial state of the trajectory.
|
||||
"""
|
||||
|
||||
def __init__(self, events: list[Event] | None):
|
||||
replay_events = []
|
||||
for event in events or []:
|
||||
if event.source == EventSource.ENVIRONMENT:
|
||||
# ignore ENVIRONMENT events as they are not issued by
|
||||
# the user or agent, and should not be replayed
|
||||
continue
|
||||
if isinstance(event, NullObservation):
|
||||
# ignore NullObservation
|
||||
continue
|
||||
replay_events.append(event)
|
||||
|
||||
if replay_events:
|
||||
logger.info(f'Replay events loaded, events length = {len(replay_events)}')
|
||||
for index in range(len(replay_events) - 1):
|
||||
event = replay_events[index]
|
||||
if isinstance(event, MessageAction) and event.wait_for_response:
|
||||
# For any message waiting for response that is not the last
|
||||
# event, we override wait_for_response to False, as a response
|
||||
# would have been included in the next event, and we don't
|
||||
# want the user to interfere with the replay process
|
||||
logger.info(
|
||||
'Replay events contains wait_for_response message action, ignoring wait_for_response'
|
||||
)
|
||||
event.wait_for_response = False
|
||||
self.replay_events = replay_events
|
||||
self.replay_mode = bool(replay_events)
|
||||
self.replay_index = 0
|
||||
|
||||
def _replayable(self) -> bool:
|
||||
return (
|
||||
self.replay_events is not None
|
||||
and self.replay_index < len(self.replay_events)
|
||||
and isinstance(self.replay_events[self.replay_index], Action)
|
||||
)
|
||||
|
||||
def should_replay(self) -> bool:
|
||||
"""Whether the controller is in trajectory replay mode, and the replay
|
||||
hasn't finished. Note: after the replay is finished, the user and
|
||||
the agent could continue to message/act.
|
||||
|
||||
This method also moves "replay_index" to the next action, if applicable.
|
||||
"""
|
||||
if not self.replay_mode:
|
||||
return False
|
||||
|
||||
assert self.replay_events is not None
|
||||
while self.replay_index < len(self.replay_events) and not self._replayable():
|
||||
self.replay_index += 1
|
||||
|
||||
return self._replayable()
|
||||
|
||||
def step(self) -> Action:
|
||||
assert self.replay_events is not None
|
||||
event = self.replay_events[self.replay_index]
|
||||
assert isinstance(event, Action)
|
||||
self.replay_index += 1
|
||||
return event
|
||||
|
||||
@staticmethod
|
||||
def get_replay_events(trajectory: list[dict]) -> list[Event]:
|
||||
if not isinstance(trajectory, list):
|
||||
raise ValueError(
|
||||
f'Expected a list in {trajectory}, got {type(trajectory).__name__}'
|
||||
)
|
||||
replay_events = []
|
||||
for item in trajectory:
|
||||
event = event_from_dict(item)
|
||||
if event.source == EventSource.ENVIRONMENT:
|
||||
# ignore ENVIRONMENT events as they are not issued by
|
||||
# the user or agent, and should not be replayed
|
||||
continue
|
||||
# cannot add an event with _id to event stream
|
||||
event._id = None # type: ignore[attr-defined]
|
||||
replay_events.append(event)
|
||||
return replay_events
|
||||
@@ -1,102 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar(
|
||||
'T', int, float
|
||||
) # Type for the value (int for iterations, float for budget)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlFlag(Generic[T]):
|
||||
"""Base class for control flags that manage limits and state transitions."""
|
||||
|
||||
limit_increase_amount: T
|
||||
current_value: T
|
||||
max_value: T
|
||||
headless_mode: bool = False
|
||||
_hit_limit: bool = False
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the limit has been reached.
|
||||
|
||||
Returns:
|
||||
bool: True if the limit has been reached, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def increase_limit(self, headless_mode: bool) -> None:
|
||||
"""Expand the limit when needed."""
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self):
|
||||
"""Determine the next state based on the current state and mode.
|
||||
|
||||
Returns:
|
||||
ControlFlagState: The next state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class IterationControlFlag(ControlFlag[int]):
|
||||
"""Control flag for managing iteration limits."""
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the iteration limit has been reached."""
|
||||
self._hit_limit = self.current_value >= self.max_value
|
||||
return self._hit_limit
|
||||
|
||||
def increase_limit(self, headless_mode: bool) -> None:
|
||||
"""Expand the iteration limit by adding the initial value."""
|
||||
if not headless_mode and self._hit_limit:
|
||||
self.max_value += self.limit_increase_amount
|
||||
self._hit_limit = False
|
||||
|
||||
def step(self):
|
||||
if self.reached_limit():
|
||||
raise RuntimeError(
|
||||
f'Agent reached maximum iteration. '
|
||||
f'Current iteration: {self.current_value}, max iteration: {self.max_value}'
|
||||
)
|
||||
|
||||
# Increment the current value
|
||||
self.current_value += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class BudgetControlFlag(ControlFlag[float]):
|
||||
"""Control flag for managing budget limits."""
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the budget limit has been reached."""
|
||||
self._hit_limit = self.current_value >= self.max_value
|
||||
return self._hit_limit
|
||||
|
||||
def increase_limit(self, headless_mode) -> None:
|
||||
"""Expand the budget limit by adding the initial value to the current value."""
|
||||
if self._hit_limit:
|
||||
self.max_value = self.current_value + self.limit_increase_amount
|
||||
self._hit_limit = False
|
||||
|
||||
def step(self):
|
||||
"""Check if we've reached the limit and update state accordingly.
|
||||
|
||||
Note: Unlike IterationControlFlag, this doesn't increment the value
|
||||
as the budget is updated externally.
|
||||
"""
|
||||
if self.reached_limit():
|
||||
current_str = f'{self.current_value:.2f}'
|
||||
max_str = f'{self.max_value:.2f}'
|
||||
raise RuntimeError(
|
||||
f'Agent reached maximum budget for conversation.'
|
||||
f'Current budget: {current_str}, max budget: {max_str}'
|
||||
)
|
||||
@@ -1,318 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import openhands
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.view import View
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_agent_state_filename
|
||||
|
||||
RESUMABLE_STATES = [
|
||||
AgentState.RUNNING,
|
||||
AgentState.PAUSED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]
|
||||
|
||||
|
||||
# NOTE: this is deprecated
|
||||
class TrafficControlState(str, Enum):
|
||||
# default state, no rate limiting
|
||||
NORMAL = 'normal'
|
||||
|
||||
# task paused due to traffic control
|
||||
THROTTLING = 'throttling'
|
||||
|
||||
# traffic control is temporarily paused
|
||||
PAUSED = 'paused'
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Represents the running state of an agent in the OpenHands system, saving data of its operation and memory.
|
||||
|
||||
- Multi-agent/delegate state:
|
||||
- store the task (conversation between the agent and the user)
|
||||
- the subtask (conversation between an agent and the user or another agent)
|
||||
- global and local iterations
|
||||
- delegate levels for multi-agent interactions
|
||||
- almost stuck state
|
||||
|
||||
- Running state of an agent:
|
||||
- current agent state (e.g., LOADING, RUNNING, PAUSED)
|
||||
- traffic control state for rate limiting
|
||||
- confirmation mode
|
||||
- the last error encountered
|
||||
|
||||
- Data for saving and restoring the agent:
|
||||
- save to and restore from a session
|
||||
- serialize with pickle and base64
|
||||
|
||||
- Save / restore data about message history
|
||||
- start and end IDs for events in agent's history
|
||||
- summaries and delegate summaries
|
||||
|
||||
- Metrics:
|
||||
- global metrics for the current task
|
||||
- local metrics for the current subtask
|
||||
|
||||
- Extra data:
|
||||
- additional task-specific data
|
||||
"""
|
||||
|
||||
session_id: str = ''
|
||||
user_id: str | None = None
|
||||
iteration_flag: IterationControlFlag = field(
|
||||
default_factory=lambda: IterationControlFlag(
|
||||
limit_increase_amount=100, current_value=0, max_value=100
|
||||
)
|
||||
)
|
||||
conversation_stats: ConversationStats | None = None
|
||||
budget_flag: BudgetControlFlag | None = None
|
||||
confirmation_mode: bool = False
|
||||
history: list[Event] = field(default_factory=list)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
resume_state: AgentState | None = None
|
||||
|
||||
# root agent has level 0, and every delegate increases the level by one
|
||||
delegate_level: int = 0
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
|
||||
parent_metrics_snapshot: Metrics | None = None
|
||||
parent_iteration: int = 100
|
||||
|
||||
# NOTE: this is used by the controller to track parent's metrics snapshot before delegation
|
||||
# evaluation tasks to store extra data needed to track the progress/state of the task.
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
last_error: str = ''
|
||||
|
||||
# NOTE: deprecated args, kept here temporarily for backwards compatability
|
||||
# Will be remove in 30 days
|
||||
iteration: int | None = None
|
||||
local_iteration: int | None = None
|
||||
max_iterations: int | None = None
|
||||
traffic_control_state: TrafficControlState | None = None
|
||||
local_metrics: Metrics | None = None
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] | None = None
|
||||
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
|
||||
def save_to_session(
|
||||
self, sid: str, file_store: FileStore, user_id: str | None
|
||||
) -> None:
|
||||
conversation_stats = self.conversation_stats
|
||||
self.conversation_stats = None # Don't save conversation stats, handles itself
|
||||
|
||||
pickled = pickle.dumps(self)
|
||||
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
|
||||
encoded = base64.b64encode(pickled).decode('utf-8')
|
||||
try:
|
||||
file_store.write(
|
||||
get_conversation_agent_state_filename(sid, user_id), encoded
|
||||
)
|
||||
|
||||
# see if state is in the old directory on saas/remote use cases and delete it.
|
||||
if user_id:
|
||||
filename = get_conversation_agent_state_filename(sid)
|
||||
try:
|
||||
file_store.delete(filename)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to save state to session: {e}')
|
||||
raise e
|
||||
|
||||
self.conversation_stats = conversation_stats # restore reference
|
||||
|
||||
@staticmethod
|
||||
def restore_from_session(
|
||||
sid: str, file_store: FileStore, user_id: str | None = None
|
||||
) -> 'State':
|
||||
"""Restores the state from the previously saved session."""
|
||||
state: State
|
||||
try:
|
||||
encoded = file_store.read(
|
||||
get_conversation_agent_state_filename(sid, user_id)
|
||||
)
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
except FileNotFoundError:
|
||||
# if user_id is provided, we are in a saas/remote use case
|
||||
# and we need to check if the state is in the old directory.
|
||||
if user_id:
|
||||
filename = get_conversation_agent_state_filename(sid)
|
||||
encoded = file_store.read(filename)
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f'Could not restore state from session file for sid: {sid}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f'Could not restore state from session: {e}')
|
||||
raise e
|
||||
|
||||
# update state
|
||||
if state.agent_state in RESUMABLE_STATES:
|
||||
state.resume_state = state.agent_state
|
||||
else:
|
||||
state.resume_state = None
|
||||
|
||||
# first state after restore
|
||||
state.agent_state = AgentState.LOADING
|
||||
|
||||
# We don't need to clean up deprecated fields here
|
||||
# They will be handled by __getstate__ when the state is saved again
|
||||
|
||||
return state
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
# don't pickle history, it will be restored from the event stream
|
||||
state = self.__dict__.copy()
|
||||
state['history'] = []
|
||||
|
||||
# Remove any view caching attributes. They'll be rebuilt frmo the
|
||||
# history after that gets reloaded.
|
||||
state.pop('_history_checksum', None)
|
||||
state.pop('_view', None)
|
||||
|
||||
# Remove deprecated fields before pickling
|
||||
state.pop('iteration', None)
|
||||
state.pop('local_iteration', None)
|
||||
state.pop('max_iterations', None)
|
||||
state.pop('traffic_control_state', None)
|
||||
state.pop('local_metrics', None)
|
||||
state.pop('delegates', None)
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
# Check if we're restoring from an older version (before control flags)
|
||||
is_old_version = 'iteration' in state
|
||||
|
||||
# Convert old iteration tracking to new iteration_flag if needed
|
||||
if is_old_version:
|
||||
# Create iteration_flag from old values
|
||||
max_iterations = state.get('max_iterations', 100)
|
||||
current_iteration = state.get('iteration', 0)
|
||||
|
||||
# Add the iteration_flag to the state
|
||||
state['iteration_flag'] = IterationControlFlag(
|
||||
limit_increase_amount=max_iterations,
|
||||
current_value=current_iteration,
|
||||
max_value=max_iterations,
|
||||
)
|
||||
|
||||
# Update the state
|
||||
self.__dict__.update(state)
|
||||
|
||||
# We keep the deprecated fields for backward compatibility
|
||||
# They will be removed by __getstate__ when the state is saved again
|
||||
|
||||
# make sure we always have the attribute history
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = []
|
||||
|
||||
# Ensure we have default values for new fields if they're missing
|
||||
if not hasattr(self, 'iteration_flag'):
|
||||
self.iteration_flag = IterationControlFlag(
|
||||
limit_increase_amount=100, current_value=0, max_value=100
|
||||
)
|
||||
|
||||
if not hasattr(self, 'budget_flag'):
|
||||
self.budget_flag = None
|
||||
|
||||
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.image_urls
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
if last_user_message is not None:
|
||||
return last_user_message, None
|
||||
|
||||
return last_user_message, last_user_message_image_urls
|
||||
|
||||
def get_last_agent_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
return event
|
||||
return None
|
||||
|
||||
def get_last_user_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
return event
|
||||
return None
|
||||
|
||||
def to_llm_metadata(self, model_name: str, agent_name: str) -> dict:
|
||||
metadata = {
|
||||
'session_id': self.session_id,
|
||||
'trace_version': openhands.__version__,
|
||||
'trace_user_id': self.user_id,
|
||||
'tags': [
|
||||
f'model:{model_name}',
|
||||
f'agent:{agent_name}',
|
||||
f'web_host:{os.environ.get("WEB_HOST", "unspecified")}',
|
||||
f'openhands_version:{openhands.__version__}',
|
||||
],
|
||||
}
|
||||
return metadata
|
||||
|
||||
def get_local_step(self):
|
||||
if not self.parent_iteration:
|
||||
return self.iteration_flag.current_value
|
||||
|
||||
return self.iteration_flag.current_value - self.parent_iteration
|
||||
|
||||
def get_local_metrics(self):
|
||||
if not self.parent_metrics_snapshot:
|
||||
return self.metrics
|
||||
return self.metrics.diff(self.parent_metrics_snapshot)
|
||||
|
||||
@property
|
||||
def view(self) -> View:
|
||||
# Compute a simple checksum from the history to see if we can re-use any
|
||||
# cached view.
|
||||
history_checksum = len(self.history)
|
||||
old_history_checksum = getattr(self, '_history_checksum', -1)
|
||||
|
||||
# If the history has changed, we need to re-create the view and update
|
||||
# the caching.
|
||||
if history_checksum != old_history_checksum:
|
||||
self._history_checksum = history_checksum
|
||||
self._view = View.from_events(self.history)
|
||||
|
||||
return self._view
|
||||
@@ -1,275 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import AgentDelegateAction, ChangeAgentStateAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class StateTracker:
|
||||
"""Manages and synchronizes the state of an agent throughout its lifecycle.
|
||||
|
||||
It is responsible for:
|
||||
1. Maintaining agent state persistence across sessions
|
||||
2. Managing agent history by filtering and tracking relevant events (previously done in the agent controller)
|
||||
3. Synchronizing metrics between the controller and LLM components
|
||||
4. Updating control flags for budget and iteration limits
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sid: str | None, file_store: FileStore | None, user_id: str | None
|
||||
):
|
||||
self.sid = sid
|
||||
self.file_store = file_store
|
||||
self.user_id = user_id
|
||||
|
||||
# filter out events that are not relevant to the agent
|
||||
# so they will not be included in the agent history
|
||||
self.agent_history_filter = EventFilter(
|
||||
exclude_types=(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
exclude_hidden=True,
|
||||
)
|
||||
|
||||
def set_initial_state(
|
||||
self,
|
||||
id: str,
|
||||
state: State | None,
|
||||
conversation_stats: ConversationStats,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None,
|
||||
confirmation_mode: bool = False,
|
||||
) -> None:
|
||||
"""Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
|
||||
|
||||
Args:
|
||||
state: The state to initialize with, or None to create a new state.
|
||||
max_iterations: The maximum number of iterations allowed for the task.
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
"""
|
||||
# state can come from:
|
||||
# - the previous session, in which case it has history
|
||||
# - from a parent agent, in which case it has no history
|
||||
# - None / a new state
|
||||
|
||||
# If state is None, we create a brand new state and still load the event stream so we can restore the history
|
||||
if state is None:
|
||||
self.state = State(
|
||||
session_id=id.removesuffix('-delegate'),
|
||||
user_id=self.user_id,
|
||||
inputs={},
|
||||
conversation_stats=conversation_stats,
|
||||
iteration_flag=IterationControlFlag(
|
||||
limit_increase_amount=max_iterations,
|
||||
current_value=0,
|
||||
max_value=max_iterations,
|
||||
),
|
||||
budget_flag=None
|
||||
if not max_budget_per_task
|
||||
else BudgetControlFlag(
|
||||
limit_increase_amount=max_budget_per_task,
|
||||
current_value=0,
|
||||
max_value=max_budget_per_task,
|
||||
),
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
self.state.start_id = 0
|
||||
|
||||
logger.info(
|
||||
f'AgentController {id} - created new state. start_id: {self.state.start_id}'
|
||||
)
|
||||
else:
|
||||
self.state = state
|
||||
if self.state.start_id <= -1:
|
||||
self.state.start_id = 0
|
||||
|
||||
state.conversation_stats = conversation_stats
|
||||
|
||||
def _init_history(self, event_stream: EventStream) -> None:
|
||||
"""Initializes the agent's history from the event stream.
|
||||
|
||||
The history is a list of events that:
|
||||
- Excludes events of types listed in self.filter_out
|
||||
- Excludes events with hidden=True attribute
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
"""
|
||||
# define range of events to fetch
|
||||
# delegates start with a start_id and initially won't find any events
|
||||
# otherwise we're restoring a previous session
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
# sanity check
|
||||
if start_id > end_id + 1:
|
||||
logger.warning(
|
||||
f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.',
|
||||
)
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
events.extend(events_to_add)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, AgentDelegateAction):
|
||||
delegate_action_ids.append(event.id)
|
||||
# Note: we can get agent=event.agent and task=event.inputs.get('task','')
|
||||
# if we need to track these in the future
|
||||
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
# Match with most recent unmatched delegate action
|
||||
if not delegate_action_ids:
|
||||
logger.warning(
|
||||
f'Found AgentDelegateObservation without matching action at id={event.id}',
|
||||
)
|
||||
continue
|
||||
|
||||
action_id = delegate_action_ids.pop()
|
||||
delegate_ranges.append((action_id, event.id))
|
||||
|
||||
# Filter out events between delegate action/observation pairs
|
||||
if delegate_ranges:
|
||||
filtered_events: list[Event] = []
|
||||
current_idx = 0
|
||||
|
||||
for start_id, end_id in sorted(delegate_ranges):
|
||||
# Add events before delegate range
|
||||
filtered_events.extend(
|
||||
event for event in events[current_idx:] if event.id < start_id
|
||||
)
|
||||
|
||||
# Add delegate action and observation
|
||||
filtered_events.extend(
|
||||
event for event in events if event.id in (start_id, end_id)
|
||||
)
|
||||
|
||||
# Update index to after delegate range
|
||||
current_idx = next(
|
||||
(i for i, e in enumerate(events) if e.id > end_id), len(events)
|
||||
)
|
||||
|
||||
# Add any remaining events after last delegate range
|
||||
filtered_events.extend(events[current_idx:])
|
||||
|
||||
self.state.history = filtered_events
|
||||
else:
|
||||
self.state.history = events
|
||||
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
def close(self, event_stream: EventStream):
|
||||
# we made history, now is the time to rewrite it!
|
||||
# the final state.history will be used by external scripts like evals, tests, etc.
|
||||
# history will need to be complete WITH delegates events
|
||||
# like the regular agent history, it does not include:
|
||||
# - 'hidden' events, events with hidden=True
|
||||
# - backend events (the default 'filtered out' types, types in self.filter_out)
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
self.state.history = list(
|
||||
event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
|
||||
def add_history(self, event: Event):
|
||||
# if the event is not filtered out, add it to the history
|
||||
if self.agent_history_filter.include(event):
|
||||
self.state.history.append(event)
|
||||
|
||||
def get_trajectory(self, include_screenshots: bool = False) -> list[dict]:
|
||||
return [
|
||||
event_to_trajectory(event, include_screenshots)
|
||||
for event in self.state.history
|
||||
]
|
||||
|
||||
def maybe_increase_control_flags_limits(self, headless_mode: bool):
|
||||
# Iteration and budget extensions are independent of each other
|
||||
# An error will be thrown if any one of the control flags have reached or exceeded its limit
|
||||
self.state.iteration_flag.increase_limit(headless_mode)
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.increase_limit(headless_mode)
|
||||
|
||||
def get_metrics_snapshot(self):
|
||||
"""Deep copy of metrics
|
||||
This serves as a snapshot for the parent's metrics at the time a delegate is created
|
||||
It will be stored and used to compute local metrics for the delegate
|
||||
(since delegates now accumulate metrics from where its parent left off)
|
||||
"""
|
||||
return self.state.metrics.copy()
|
||||
|
||||
def save_state(self):
|
||||
"""Save's current state to persistent store"""
|
||||
if self.sid and self.file_store:
|
||||
self.state.save_to_session(self.sid, self.file_store, self.user_id)
|
||||
|
||||
if self.state.conversation_stats:
|
||||
self.state.conversation_stats.save_metrics()
|
||||
|
||||
def run_control_flags(self):
|
||||
"""Performs one step of the control flags"""
|
||||
self.state.iteration_flag.step()
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.step()
|
||||
|
||||
def sync_budget_flag_with_metrics(self):
|
||||
"""Ensures that budget flag is up to date with accumulated costs from llm completions
|
||||
Budget flag will monitor for when budget is exceeded
|
||||
"""
|
||||
# Sync cost across all llm services from llm registry
|
||||
if self.state.budget_flag and self.state.conversation_stats:
|
||||
self.state.budget_flag.current_value = (
|
||||
self.state.conversation_stats.get_combined_metrics().accumulated_cost
|
||||
)
|
||||
@@ -1,488 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.commands import IPythonRunCellAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
class StuckDetector:
|
||||
SYNTAX_ERROR_MESSAGES = [
|
||||
'SyntaxError: unterminated string literal (detected at line',
|
||||
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
'SyntaxError: incomplete input',
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class StuckAnalysis:
|
||||
loop_type: str
|
||||
loop_repeat_times: int
|
||||
loop_start_idx: int # in filtered_history
|
||||
|
||||
def __init__(self, state: State):
|
||||
self.state = state
|
||||
self.stuck_analysis: Optional[StuckDetector.StuckAnalysis] = None
|
||||
|
||||
def is_stuck(self, headless_mode: bool = True) -> bool:
|
||||
"""Checks if the agent is stuck in a loop.
|
||||
|
||||
Args:
|
||||
headless_mode: Matches AgentController's headless_mode.
|
||||
If True: Consider all history (automated/testing)
|
||||
If False: Consider only history after last user message (interactive)
|
||||
|
||||
Returns:
|
||||
bool: True if the agent is stuck in a loop, False otherwise.
|
||||
"""
|
||||
filtered_history_offset = 0
|
||||
if not headless_mode:
|
||||
# In interactive mode, only look at history after the last user message
|
||||
last_user_msg_idx = -1
|
||||
for i, event in enumerate(reversed(self.state.history)):
|
||||
if (
|
||||
isinstance(event, MessageAction)
|
||||
and event.source == EventSource.USER
|
||||
):
|
||||
last_user_msg_idx = len(self.state.history) - i - 1
|
||||
break
|
||||
filtered_history_offset = last_user_msg_idx + 1
|
||||
history_to_check = self.state.history[last_user_msg_idx + 1 :]
|
||||
else:
|
||||
# In headless mode, look at all history
|
||||
history_to_check = self.state.history
|
||||
|
||||
# Filter out user messages and null events
|
||||
filtered_history = [
|
||||
event
|
||||
for event in history_to_check
|
||||
if not (
|
||||
# Filter works elegantly in both modes:
|
||||
# - In headless: actively filters out user messages from full history
|
||||
# - In non-headless: no-op since we already sliced after last user message
|
||||
(isinstance(event, MessageAction) and event.source == EventSource.USER)
|
||||
# there might be some NullAction or NullObservation in the history at least for now
|
||||
or isinstance(event, (NullAction, NullObservation))
|
||||
)
|
||||
]
|
||||
|
||||
# it takes 3 actions minimum to detect a loop, otherwise nothing to do here
|
||||
if len(filtered_history) < 3:
|
||||
return False
|
||||
|
||||
# the first few scenarios detect 3 or 4 repeated steps
|
||||
# prepare the last 4 actions and observations, to check them out
|
||||
last_actions: list[Event] = []
|
||||
last_observations: list[Event] = []
|
||||
|
||||
# retrieve the last four actions and observations starting from the end of history, wherever they are
|
||||
for event in reversed(filtered_history):
|
||||
if isinstance(event, Action) and len(last_actions) < 4:
|
||||
last_actions.append(event)
|
||||
elif isinstance(event, Observation) and len(last_observations) < 4:
|
||||
last_observations.append(event)
|
||||
|
||||
if len(last_actions) == 4 and len(last_observations) == 4:
|
||||
break
|
||||
|
||||
# scenario 1: same action, same observation
|
||||
if self._is_stuck_repeating_action_observation(
|
||||
last_actions, last_observations, filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 2: same action, errors
|
||||
if self._is_stuck_repeating_action_error(
|
||||
last_actions, last_observations, filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 3: monologue
|
||||
if self._is_stuck_monologue(filtered_history, filtered_history_offset):
|
||||
return True
|
||||
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
if len(filtered_history) >= 6:
|
||||
if self._is_stuck_action_observation_pattern(
|
||||
filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 5: context window error loop
|
||||
if len(filtered_history) >= 10:
|
||||
if self._is_stuck_context_window_error(
|
||||
filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# Empty stuck_analysis when not stuck
|
||||
self.stuck_analysis = None
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_observation(
|
||||
self,
|
||||
last_actions: list[Event],
|
||||
last_observations: list[Event],
|
||||
filtered_history: list[Event],
|
||||
filtered_history_offset: int = 0,
|
||||
) -> bool:
|
||||
# scenario 1: same action, same observation
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# assert len(last_actions) == 4 and len(last_observations) == 4
|
||||
|
||||
# Check for a loop of 4 identical action-observation pairs
|
||||
if len(last_actions) == 4 and len(last_observations) == 4:
|
||||
actions_equal = all(
|
||||
self._eq_no_pid(last_actions[0], action) for action in last_actions
|
||||
)
|
||||
observations_equal = all(
|
||||
self._eq_no_pid(last_observations[0], observation)
|
||||
for observation in last_observations
|
||||
)
|
||||
|
||||
if actions_equal and observations_equal:
|
||||
logger.warning('Action, Observation loop detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_observation',
|
||||
loop_repeat_times=4,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_error(
|
||||
self,
|
||||
last_actions: list[Event],
|
||||
last_observations: list[Event],
|
||||
filtered_history: list[Event],
|
||||
filtered_history_offset: int = 0,
|
||||
) -> bool:
|
||||
# scenario 2: same action, errors
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
# check if the last three actions are the same and result in errors
|
||||
|
||||
if len(last_actions) < 3 or len(last_observations) < 3:
|
||||
return False
|
||||
|
||||
# are the last three actions the "same"?
|
||||
if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
|
||||
# and the last three observations are all errors?
|
||||
if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
|
||||
logger.warning('Action, ErrorObservation loop detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
# or, are the last three observations all IPythonRunCellObservation with SyntaxError?
|
||||
elif all(
|
||||
isinstance(obs, IPythonRunCellObservation)
|
||||
for obs in last_observations[:3]
|
||||
):
|
||||
warning = 'Action, IPythonRunCellObservation loop detected'
|
||||
for error_message in self.SYNTAX_ERROR_MESSAGES:
|
||||
if error_message.startswith(
|
||||
'SyntaxError: unterminated string literal (detected at line'
|
||||
):
|
||||
if self._check_for_consistent_line_error(
|
||||
[
|
||||
obs
|
||||
for obs in last_observations[:3]
|
||||
if isinstance(obs, IPythonRunCellObservation)
|
||||
],
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
elif error_message in (
|
||||
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
'SyntaxError: incomplete input',
|
||||
) and self._check_for_consistent_invalid_syntax(
|
||||
[
|
||||
obs
|
||||
for obs in last_observations[:3]
|
||||
if isinstance(obs, IPythonRunCellObservation)
|
||||
],
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_for_consistent_invalid_syntax(
|
||||
self, observations: list[IPythonRunCellObservation], error_message: str
|
||||
) -> bool:
|
||||
first_lines = []
|
||||
valid_observations = []
|
||||
|
||||
for obs in observations:
|
||||
content = obs.content
|
||||
lines = content.strip().split('\n')
|
||||
|
||||
if len(lines) < 6: # 6 because a real syntax error has at least 6 lines
|
||||
return False
|
||||
|
||||
line1 = lines[0].strip()
|
||||
if not line1.startswith('Cell In[1], line'):
|
||||
return False
|
||||
|
||||
first_lines.append(line1) # Store the first line of each observation
|
||||
|
||||
# Check last three lines
|
||||
if (
|
||||
lines[-1].startswith('[Jupyter Python interpreter:')
|
||||
and lines[-2].startswith('[Jupyter current working directory:')
|
||||
and error_message in lines[-3]
|
||||
):
|
||||
valid_observations.append(obs)
|
||||
|
||||
# Check if:
|
||||
# 1. All first lines are identical
|
||||
# 2. We have exactly 3 valid observations
|
||||
# 3. The error message line is identical in all valid observations
|
||||
return (
|
||||
len(set(first_lines)) == 1
|
||||
and len(valid_observations) == 3
|
||||
and len(
|
||||
set(
|
||||
obs.content.strip().split('\n')[:-2][-1]
|
||||
for obs in valid_observations
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
|
||||
def _check_for_consistent_line_error(
|
||||
self, observations: list[IPythonRunCellObservation], error_message: str
|
||||
) -> bool:
|
||||
error_lines = []
|
||||
|
||||
for obs in observations:
|
||||
content = obs.content
|
||||
lines = content.strip().split('\n')
|
||||
|
||||
if len(lines) < 3:
|
||||
return False
|
||||
|
||||
last_lines = lines[-3:]
|
||||
|
||||
# Check if the last two lines are our own
|
||||
if not (
|
||||
last_lines[-2].startswith('[Jupyter current working directory:')
|
||||
and last_lines[-1].startswith('[Jupyter Python interpreter:')
|
||||
):
|
||||
return False
|
||||
|
||||
# Check for the error message in the 3rd-to-last line
|
||||
if error_message in last_lines[-3]:
|
||||
error_lines.append(last_lines[-3])
|
||||
|
||||
# Check if we found the error message in all 3 observations
|
||||
# and the 3rd-to-last line is identical across all occurrences
|
||||
return len(error_lines) == 3 and len(set(error_lines)) == 1
|
||||
|
||||
def _is_stuck_monologue(
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
# scenario 3: monologue
|
||||
# check for repeated MessageActions with source=AGENT
|
||||
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
|
||||
agent_message_actions = [
|
||||
(i, event)
|
||||
for i, event in enumerate(filtered_history)
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.AGENT
|
||||
]
|
||||
|
||||
# last three message actions will do for this check
|
||||
if len(agent_message_actions) >= 3:
|
||||
last_agent_message_actions = agent_message_actions[-3:]
|
||||
|
||||
if all(
|
||||
(last_agent_message_actions[0][1] == action[1])
|
||||
for action in last_agent_message_actions
|
||||
):
|
||||
# check if there are any observations between the repeated MessageActions
|
||||
# then it's not yet a loop, maybe it can recover
|
||||
start_index = last_agent_message_actions[0][0]
|
||||
end_index = last_agent_message_actions[-1][0]
|
||||
|
||||
has_observation_between = False
|
||||
for event in filtered_history[start_index + 1 : end_index]:
|
||||
if isinstance(event, Observation):
|
||||
has_observation_between = True
|
||||
break
|
||||
|
||||
if not has_observation_between:
|
||||
logger.warning('Repeated MessageAction with source=AGENT detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='monologue',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=start_index + filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_action_observation_pattern(
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
# check if the agent repeats the same (Action, Observation)
|
||||
# every other step in the last six steps
|
||||
last_six_actions: list[Event] = []
|
||||
last_six_observations: list[Event] = []
|
||||
|
||||
# the end of history is most interesting
|
||||
for event in reversed(filtered_history):
|
||||
if isinstance(event, Action) and len(last_six_actions) < 6:
|
||||
last_six_actions.append(event)
|
||||
elif isinstance(event, Observation) and len(last_six_observations) < 6:
|
||||
last_six_observations.append(event)
|
||||
|
||||
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
|
||||
break
|
||||
|
||||
# this pattern is every other step, like:
|
||||
# (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
|
||||
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
|
||||
actions_equal = (
|
||||
# action_0 == action_2 == action_4
|
||||
self._eq_no_pid(last_six_actions[0], last_six_actions[2])
|
||||
and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
|
||||
# action_1 == action_3 == action_5
|
||||
and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
|
||||
and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
|
||||
)
|
||||
observations_equal = (
|
||||
# obs_0 == obs_2 == obs_4
|
||||
self._eq_no_pid(last_six_observations[0], last_six_observations[2])
|
||||
and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
|
||||
# obs_1 == obs_3 == obs_5
|
||||
and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
|
||||
and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
|
||||
)
|
||||
|
||||
if actions_equal and observations_equal:
|
||||
logger.warning('Action, Observation pattern detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_observation_pattern',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_six_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_context_window_error(
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
"""Detects if we're stuck in a loop of context window errors.
|
||||
|
||||
This happens when we repeatedly get context window errors and try to trim,
|
||||
but the trimming doesn't work, causing us to get more context window errors.
|
||||
The pattern is repeated AgentCondensationObservation events without any other
|
||||
events between them.
|
||||
|
||||
Args:
|
||||
filtered_history: List of filtered events to check
|
||||
|
||||
Returns:
|
||||
bool: True if we detect a context window error loop
|
||||
"""
|
||||
# Look for AgentCondensationObservation events
|
||||
condensation_events = [
|
||||
(i, event)
|
||||
for i, event in enumerate(filtered_history)
|
||||
if isinstance(event, AgentCondensationObservation)
|
||||
]
|
||||
|
||||
# Need at least 10 condensation events to detect a loop
|
||||
if len(condensation_events) < 10:
|
||||
return False
|
||||
|
||||
# Get the last 10 condensation events
|
||||
last_condensation_events = condensation_events[-10:]
|
||||
|
||||
# Check if there are any non-condensation events between them
|
||||
for i in range(len(last_condensation_events) - 1):
|
||||
start_idx = last_condensation_events[i][0]
|
||||
end_idx = last_condensation_events[i + 1][0]
|
||||
|
||||
# Look for any non-condensation events between these two
|
||||
has_other_events = False
|
||||
for event in filtered_history[start_idx + 1 : end_idx]:
|
||||
if not isinstance(event, AgentCondensationObservation):
|
||||
has_other_events = True
|
||||
break
|
||||
|
||||
if not has_other_events:
|
||||
logger.warning(
|
||||
'Context window error loop detected - repeated condensation events'
|
||||
)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='context_window_error',
|
||||
loop_repeat_times=2,
|
||||
loop_start_idx=start_idx + filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _eq_no_pid(self, obj1: Event, obj2: Event) -> bool:
|
||||
if isinstance(obj1, IPythonRunCellAction) and isinstance(
|
||||
obj2, IPythonRunCellAction
|
||||
):
|
||||
# for loop detection on edit actions, ignore the thought, compare some code
|
||||
# the code should have at least 3 lines, to avoid simple one-liners
|
||||
if (
|
||||
'edit_file_by_replace(' in obj1.code
|
||||
and 'edit_file_by_replace(' in obj2.code
|
||||
):
|
||||
return (
|
||||
len(obj1.code.split('\n')) > 2
|
||||
and obj1.code.split('\n')[:3] == obj2.code.split('\n')[:3]
|
||||
)
|
||||
else:
|
||||
# default comparison
|
||||
return obj1 == obj2
|
||||
elif isinstance(obj1, CmdOutputObservation) and isinstance(
|
||||
obj2, CmdOutputObservation
|
||||
):
|
||||
# for loop detection, ignore command_id, which is the pid
|
||||
return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
|
||||
else:
|
||||
# this is the default comparison
|
||||
return obj1 == obj2
|
||||
@@ -16,7 +16,6 @@ from openhands.core.config.condenser_config import (
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.model_routing_config import ModelRoutingConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
@@ -130,38 +129,4 @@ class AgentConfig(BaseModel):
|
||||
# Still add it to the mapping
|
||||
agent_mapping['agent'] = base_config
|
||||
|
||||
# Process each custom section independently
|
||||
for name, overrides in custom_sections.items():
|
||||
try:
|
||||
# Merge base config with overrides
|
||||
merged = {**base_config.model_dump(), **overrides}
|
||||
if merged.get('classpath'):
|
||||
# if an explicit classpath is given, try to load it and look up its config model class
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
try:
|
||||
agent_cls = get_impl(Agent, merged.get('classpath'))
|
||||
custom_config = agent_cls.config_model.model_validate(merged)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to load custom agent class [{merged.get("classpath")}]: {e}. Using default config model.'
|
||||
)
|
||||
custom_config = cls.model_validate(merged)
|
||||
else:
|
||||
# otherwise, try to look up the agent class by name (i.e. if it's a built-in)
|
||||
# if that fails, just use the default AgentConfig class.
|
||||
try:
|
||||
agent_cls = Agent.get_cls(name)
|
||||
custom_config = agent_cls.config_model.model_validate(merged)
|
||||
except Exception:
|
||||
# otherwise, just fall back to the default config model
|
||||
custom_config = cls.model_validate(merged)
|
||||
agent_mapping[name] = custom_config
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f'Invalid agent configuration for [{name}]: {e}. This section will be skipped.'
|
||||
)
|
||||
# Skip this custom section but continue with others
|
||||
continue
|
||||
|
||||
return agent_mapping
|
||||
|
||||
@@ -8,15 +8,6 @@
|
||||
"""Centralized command line argument configuration for OpenHands CLI and headless modes."""
|
||||
|
||||
import argparse
|
||||
from argparse import ArgumentParser, _SubParsersAction
|
||||
|
||||
|
||||
def get_subparser(parser: ArgumentParser, name: str) -> ArgumentParser:
|
||||
for action in parser._actions:
|
||||
if isinstance(action, _SubParsersAction):
|
||||
if name in action.choices:
|
||||
return action.choices[name]
|
||||
raise ValueError(f"Subparser '{name}' not found")
|
||||
|
||||
|
||||
def add_common_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
@@ -149,71 +140,6 @@ def add_headless_specific_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_cli_parser() -> argparse.ArgumentParser:
|
||||
"""Create argument parser for CLI mode with simplified argument set."""
|
||||
# Create a description with welcome message explaining available commands
|
||||
description = (
|
||||
'Welcome to OpenHands: Code Less, Make More\n\n'
|
||||
'OpenHands supports two main commands:\n'
|
||||
' serve - Launch the OpenHands GUI server (web interface)\n'
|
||||
' cli - Run OpenHands in CLI mode (terminal interface)\n\n'
|
||||
'Running "openhands" without a command is the same as "openhands cli"'
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
prog='openhands',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter, # Preserve formatting in description
|
||||
epilog='For more information about a command, run: openhands COMMAND --help',
|
||||
)
|
||||
|
||||
# Create subparsers
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='command',
|
||||
title='commands',
|
||||
description='OpenHands supports two main commands:',
|
||||
metavar='COMMAND',
|
||||
)
|
||||
|
||||
# Add 'serve' subcommand
|
||||
serve_parser = subparsers.add_parser(
|
||||
'serve', help='Launch the OpenHands GUI server using Docker (web interface)'
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
'--mount-cwd',
|
||||
help='Mount the current working directory into the GUI server container',
|
||||
action='store_true',
|
||||
default=False,
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
'--gpu',
|
||||
help='Enable GPU support by mounting all GPUs into the Docker container via nvidia-docker',
|
||||
action='store_true',
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Add 'cli' subcommand - import all the existing CLI arguments
|
||||
cli_parser = subparsers.add_parser(
|
||||
'cli', help='Run OpenHands in CLI mode (terminal interface)'
|
||||
)
|
||||
add_common_arguments(cli_parser)
|
||||
|
||||
cli_parser.add_argument(
|
||||
'--override-cli-mode',
|
||||
help='Override the default settings for CLI mode',
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--conversation',
|
||||
help='The conversation id to continue',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_headless_parser() -> argparse.ArgumentParser:
|
||||
"""Create argument parser for headless mode with full argument set."""
|
||||
parser = argparse.ArgumentParser(description='Run the agent via CLI')
|
||||
|
||||
@@ -23,9 +23,7 @@ from openhands.core import logger
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.config.arg_utils import get_headless_parser
|
||||
from openhands.core.config.condenser_config import (
|
||||
CondenserConfig,
|
||||
condenser_config_from_toml_section,
|
||||
create_condenser_config,
|
||||
)
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.kubernetes_config import KubernetesConfig
|
||||
@@ -37,7 +35,6 @@ from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
JWT_SECRET = '.jwt_secret'
|
||||
load_dotenv()
|
||||
@@ -628,118 +625,6 @@ def get_llms_for_routing_config(toml_file: str = 'config.toml') -> dict[str, LLM
|
||||
return llms_for_routing
|
||||
|
||||
|
||||
def get_condenser_config_arg(
|
||||
condenser_config_arg: str, toml_file: str = 'config.toml'
|
||||
) -> CondenserConfig | None:
|
||||
"""Get a group of condenser settings from the config file by name.
|
||||
|
||||
A group in config.toml can look like this:
|
||||
|
||||
```
|
||||
[condenser.my_summarizer]
|
||||
type = 'llm'
|
||||
llm_config = 'gpt-4o' # References [llm.gpt-4o]
|
||||
max_size = 50
|
||||
...
|
||||
```
|
||||
|
||||
The user-defined group name, like "my_summarizer", is the argument to this function.
|
||||
The function will load the CondenserConfig object with the settings of this group,
|
||||
from the config file.
|
||||
|
||||
Note that the group must be under the "condenser" group, or in other words,
|
||||
the group name must start with "condenser.".
|
||||
|
||||
Args:
|
||||
condenser_config_arg: The group of condenser settings to get from the config.toml file.
|
||||
toml_file: Path to the configuration file to read from. Defaults to 'config.toml'.
|
||||
|
||||
Returns:
|
||||
CondenserConfig: The CondenserConfig object with the settings from the config file, or None if not found/error.
|
||||
"""
|
||||
# keep only the name, just in case
|
||||
condenser_config_arg = condenser_config_arg.strip('[]')
|
||||
|
||||
# truncate the prefix, just in case
|
||||
if condenser_config_arg.startswith('condenser.'):
|
||||
condenser_config_arg = condenser_config_arg[10:]
|
||||
|
||||
logger.openhands_logger.debug(
|
||||
f'Loading condenser config [{condenser_config_arg}] from {toml_file}'
|
||||
)
|
||||
|
||||
# load the toml file
|
||||
try:
|
||||
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
|
||||
toml_config = toml.load(toml_contents)
|
||||
except FileNotFoundError as e:
|
||||
logger.openhands_logger.info(f'Config file not found: {toml_file}. Error: {e}')
|
||||
return None
|
||||
except toml.TomlDecodeError as e:
|
||||
logger.openhands_logger.error(
|
||||
f'Cannot parse condenser group [{condenser_config_arg}] from {toml_file}. Exception: {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if the condenser section and the specific config exist
|
||||
if (
|
||||
'condenser' not in toml_config
|
||||
or condenser_config_arg not in toml_config['condenser']
|
||||
):
|
||||
logger.openhands_logger.error(
|
||||
f'Condenser config section [condenser.{condenser_config_arg}] not found in {toml_file}'
|
||||
)
|
||||
return None
|
||||
|
||||
condenser_data = toml_config['condenser'][
|
||||
condenser_config_arg
|
||||
].copy() # Use copy to modify
|
||||
|
||||
# Determine the type and handle potential LLM dependency
|
||||
condenser_type = condenser_data.get('type')
|
||||
if not condenser_type:
|
||||
logger.openhands_logger.error(
|
||||
f'Missing "type" field in [condenser.{condenser_config_arg}] section of {toml_file}'
|
||||
)
|
||||
return None
|
||||
|
||||
# Handle LLM config reference if needed, using get_llm_config_arg
|
||||
if (
|
||||
condenser_type in ('llm', 'llm_attention', 'structured')
|
||||
and 'llm_config' in condenser_data
|
||||
and isinstance(condenser_data['llm_config'], str)
|
||||
):
|
||||
llm_config_name = condenser_data['llm_config']
|
||||
logger.openhands_logger.debug(
|
||||
f'Condenser [{condenser_config_arg}] requires LLM config [{llm_config_name}]. Loading it...'
|
||||
)
|
||||
# Use the existing function to load the specific LLM config
|
||||
referenced_llm_config = get_llm_config_arg(llm_config_name, toml_file=toml_file)
|
||||
|
||||
if referenced_llm_config:
|
||||
# Replace the string reference with the actual LLMConfig object
|
||||
condenser_data['llm_config'] = referenced_llm_config
|
||||
else:
|
||||
# get_llm_config_arg already logs the error if not found
|
||||
logger.openhands_logger.error(
|
||||
f"Failed to load required LLM config '{llm_config_name}' for condenser '{condenser_config_arg}'."
|
||||
)
|
||||
return None
|
||||
|
||||
# Create the condenser config instance
|
||||
try:
|
||||
config = create_condenser_config(condenser_type, condenser_data)
|
||||
logger.openhands_logger.info(
|
||||
f'Successfully loaded condenser config [{condenser_config_arg}] from {toml_file}'
|
||||
)
|
||||
return config
|
||||
except (ValidationError, ValueError) as e:
|
||||
logger.openhands_logger.error(
|
||||
f'Invalid condenser configuration for [{condenser_config_arg}]: {e}.'
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_model_routing_config_arg(toml_file: str = 'config.toml') -> ModelRoutingConfig:
|
||||
"""Get the model routing settings from the config file. We only support the default model routing config [model_routing].
|
||||
|
||||
@@ -797,29 +682,6 @@ def parse_arguments() -> argparse.Namespace:
|
||||
return args
|
||||
|
||||
|
||||
def register_custom_agents(config: OpenHandsConfig) -> None:
|
||||
"""Register custom agents from configuration.
|
||||
|
||||
This function is called after configuration is loaded to ensure all custom agents
|
||||
specified in the config are properly imported and registered.
|
||||
"""
|
||||
# Import here to avoid circular dependency
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
for agent_name, agent_config in config.agents.items():
|
||||
if agent_config.classpath:
|
||||
try:
|
||||
agent_cls = get_impl(Agent, agent_config.classpath)
|
||||
Agent.register(agent_name, agent_cls)
|
||||
logger.openhands_logger.info(
|
||||
f"Registered custom agent '{agent_name}' from {agent_config.classpath}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.openhands_logger.error(
|
||||
f"Failed to register agent '{agent_name}': {e}"
|
||||
)
|
||||
|
||||
|
||||
def load_openhands_config(
|
||||
set_logging_levels: bool = True, config_file: str = 'config.toml'
|
||||
) -> OpenHandsConfig:
|
||||
@@ -833,7 +695,6 @@ def load_openhands_config(
|
||||
load_from_toml(config, config_file)
|
||||
load_from_env(config, os.environ)
|
||||
finalize_config(config)
|
||||
register_custom_agents(config)
|
||||
if set_logging_levels:
|
||||
logger.DEBUG = config.debug
|
||||
logger.DISABLE_COLOR_PRINTING = config.disable_color
|
||||
|
||||
@@ -16,16 +16,6 @@ class AgentError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNoInstructionError(AgentError):
|
||||
def __init__(self, message: str = 'Instruction must be provided') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentEventTypeError(AgentError):
|
||||
def __init__(self, message: str = 'Event must be a dictionary') -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentAlreadyRegisteredError(AgentError):
|
||||
def __init__(self, name: str | None = None) -> None:
|
||||
if name is not None:
|
||||
@@ -49,20 +39,6 @@ class AgentStuckInLoopError(AgentError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# ============================================
|
||||
# Agent Controller Exceptions
|
||||
# ============================================
|
||||
|
||||
|
||||
class TaskInvalidStateError(Exception):
|
||||
def __init__(self, state: str | None = None) -> None:
|
||||
if state is not None:
|
||||
message = f'Invalid state {state}'
|
||||
else:
|
||||
message = 'Invalid state'
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# ============================================
|
||||
# LLM Exceptions
|
||||
# ============================================
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
import asyncio
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
|
||||
|
||||
async def run_agent_until_done(
|
||||
controller: AgentController,
|
||||
runtime: Runtime,
|
||||
memory: Memory,
|
||||
end_states: list[AgentState],
|
||||
skip_set_callback: bool = False,
|
||||
) -> None:
|
||||
"""run_agent_until_done takes a controller and a runtime, and will run
|
||||
the agent until it reaches a terminal state.
|
||||
Note that runtime must be connected before being passed in here.
|
||||
"""
|
||||
|
||||
def status_callback(msg_type: str, runtime_status: RuntimeStatus, msg: str) -> None:
|
||||
if msg_type == 'error':
|
||||
logger.error(msg)
|
||||
if controller:
|
||||
controller.state.last_error = msg
|
||||
asyncio.create_task(controller.set_agent_state_to(AgentState.ERROR))
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
if not skip_set_callback:
|
||||
if hasattr(runtime, 'status_callback') and runtime.status_callback:
|
||||
raise ValueError(
|
||||
'Runtime status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
if hasattr(controller, 'status_callback') and controller.status_callback:
|
||||
raise ValueError(
|
||||
'Controller status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
memory.status_callback = status_callback
|
||||
|
||||
while controller.state.agent_state not in end_states:
|
||||
await asyncio.sleep(1)
|
||||
@@ -1,393 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Callable, Protocol
|
||||
|
||||
from openhands.controller.replay import ReplayManager
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
OpenHandsConfig,
|
||||
parse_arguments,
|
||||
setup_config_from_args,
|
||||
)
|
||||
from openhands.core.config.mcp_config import MCPConfig, OpenHandsMCPConfigImpl
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.loop import run_agent_until_done
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.core.setup import (
|
||||
create_agent,
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
generate_sid,
|
||||
get_provider_tokens,
|
||||
initialize_repository_for_runtime,
|
||||
)
|
||||
from openhands.events import EventSource, EventStreamSubscriber
|
||||
from openhands.events.action import MessageAction, NullAction
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.mcp import add_mcp_tools_to_agent
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
from openhands.utils.utils import create_registry_and_conversation_stats
|
||||
|
||||
|
||||
class FakeUserResponseFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
state: State,
|
||||
encapsulate_solution: bool = False,
|
||||
try_parse: Callable[[Action | None], str] | None = None,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
async def run_controller(
|
||||
config: OpenHandsConfig,
|
||||
initial_user_action: Action,
|
||||
sid: str | None = None,
|
||||
runtime: Runtime | None = None,
|
||||
exit_on_message: bool = False,
|
||||
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
||||
headless_mode: bool = True,
|
||||
memory: Memory | None = None,
|
||||
conversation_instructions: str | None = None,
|
||||
) -> State | None:
|
||||
"""Main coroutine to run the agent controller with task input flexibility.
|
||||
|
||||
It's only used when you launch openhands backend directly via cmdline.
|
||||
|
||||
Args:
|
||||
config: The app config.
|
||||
initial_user_action: An Action object containing initial user input
|
||||
sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
|
||||
Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
|
||||
runtime: (optional) A runtime for the agent to run on.
|
||||
exit_on_message: quit if agent asks for a message from user (optional)
|
||||
fake_user_response_fn: An optional function that receives the current state
|
||||
(could be None) and returns a fake user response.
|
||||
headless_mode: Whether the agent is run in headless mode.
|
||||
|
||||
Returns:
|
||||
The final state of the agent, or None if an error occurred.
|
||||
|
||||
Raises:
|
||||
AssertionError: If initial_user_action is not an Action instance.
|
||||
Exception: Various exceptions may be raised during execution and will be logged.
|
||||
|
||||
Notes:
|
||||
- State persistence: If config.file_store is set, the agent's state will be
|
||||
saved between sessions.
|
||||
- Trajectories: If config.trajectories_path is set, execution history will be
|
||||
saved as JSON for analysis.
|
||||
- Budget control: Execution is limited by config.max_iterations and
|
||||
config.max_budget_per_task.
|
||||
|
||||
Example:
|
||||
>>> config = load_openhands_config()
|
||||
>>> action = MessageAction(content="Write a hello world program")
|
||||
>>> state = await run_controller(config=config, initial_user_action=action)
|
||||
"""
|
||||
sid = sid or generate_sid(config)
|
||||
|
||||
llm_registry, conversation_stats, config = create_registry_and_conversation_stats(
|
||||
config,
|
||||
sid,
|
||||
None,
|
||||
)
|
||||
|
||||
agent = create_agent(config, llm_registry)
|
||||
|
||||
# when the runtime is created, it will be connected and clone the selected repository
|
||||
repo_directory = None
|
||||
if runtime is None:
|
||||
# In itialize repository if needed
|
||||
repo_tokens = get_provider_tokens()
|
||||
runtime = create_runtime(
|
||||
config,
|
||||
llm_registry,
|
||||
sid=sid,
|
||||
headless_mode=headless_mode,
|
||||
agent=agent,
|
||||
git_provider_tokens=repo_tokens,
|
||||
)
|
||||
# Connect to the runtime
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
# Initialize repository if needed
|
||||
if config.sandbox.selected_repo:
|
||||
repo_directory = initialize_repository_for_runtime(
|
||||
runtime,
|
||||
immutable_provider_tokens=repo_tokens,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
if memory is None:
|
||||
memory = create_memory(
|
||||
runtime=runtime,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
repo_directory=repo_directory,
|
||||
conversation_instructions=conversation_instructions,
|
||||
working_dir=str(runtime.workspace_root),
|
||||
)
|
||||
|
||||
# Add MCP tools to the agent
|
||||
if agent.config.enable_mcp:
|
||||
# Add OpenHands' MCP server by default
|
||||
default_servers = await OpenHandsMCPConfigImpl.create_default_mcp_server_config(
|
||||
config.mcp_host, config, None
|
||||
)
|
||||
runtime.config.mcp = MCPConfig(
|
||||
mcpServers={**runtime.config.mcp.mcpServers, **default_servers}
|
||||
)
|
||||
|
||||
await add_mcp_tools_to_agent(agent, runtime, memory)
|
||||
|
||||
replay_events: list[Event] | None = None
|
||||
if config.replay_trajectory_path:
|
||||
logger.info('Trajectory replay is enabled')
|
||||
assert isinstance(initial_user_action, NullAction)
|
||||
replay_events, initial_user_action = load_replay_log(
|
||||
config.replay_trajectory_path
|
||||
)
|
||||
|
||||
controller, initial_state = create_controller(
|
||||
agent, runtime, config, conversation_stats, replay_events=replay_events
|
||||
)
|
||||
|
||||
assert isinstance(initial_user_action, Action), (
|
||||
f'initial user actions must be an Action, got {type(initial_user_action)}'
|
||||
)
|
||||
logger.debug(
|
||||
f'Agent Controller Initialized: Running agent {agent.name}, model '
|
||||
f'{agent.llm.config.model}, with actions: {initial_user_action}'
|
||||
)
|
||||
|
||||
# Set up asyncio-safe signal handler for graceful shutdown
|
||||
sigint_count = 0
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler():
|
||||
"""Handle SIGINT signals for graceful shutdown."""
|
||||
nonlocal sigint_count
|
||||
sigint_count += 1
|
||||
|
||||
if sigint_count == 1:
|
||||
logger.info('Received SIGINT (Ctrl+C). Initiating graceful shutdown...')
|
||||
logger.info('Press Ctrl+C again to force immediate exit.')
|
||||
shutdown_event.set()
|
||||
else:
|
||||
logger.info('Received second SIGINT. Forcing immediate exit...')
|
||||
sys.exit(1)
|
||||
|
||||
# Register the asyncio signal handler (safer for async contexts)
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
if initial_state is not None and initial_state.last_error:
|
||||
# we're resuming the previous session
|
||||
event_stream.add_event(
|
||||
MessageAction(
|
||||
content=(
|
||||
"Let's get back on track. If you experienced errors before, do "
|
||||
'NOT resume your task. Ask me about it.'
|
||||
),
|
||||
),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
# init with the provided actions
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
|
||||
def on_event(event: Event) -> None:
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.AWAITING_USER_INPUT:
|
||||
if exit_on_message:
|
||||
message = '/exit'
|
||||
elif fake_user_response_fn is None:
|
||||
message = read_input(config.cli_multiline_input)
|
||||
else:
|
||||
message = fake_user_response_fn(controller.get_state())
|
||||
action = MessageAction(content=message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
|
||||
|
||||
end_states = [
|
||||
AgentState.FINISHED,
|
||||
AgentState.REJECTED,
|
||||
AgentState.ERROR,
|
||||
AgentState.PAUSED,
|
||||
AgentState.STOPPED,
|
||||
]
|
||||
|
||||
try:
|
||||
# Create a task for the main agent loop
|
||||
agent_task = asyncio.create_task(
|
||||
run_agent_until_done(controller, runtime, memory, end_states)
|
||||
)
|
||||
|
||||
# Wait for either the agent to complete or shutdown signal
|
||||
done, pending = await asyncio.wait(
|
||||
[agent_task, asyncio.create_task(shutdown_event.wait())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# Cancel any pending tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all cancelled tasks to complete in parallel
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
# Check if shutdown was requested
|
||||
if shutdown_event.is_set():
|
||||
logger.info('Graceful shutdown requested.')
|
||||
|
||||
# Perform graceful cleanup sequence
|
||||
try:
|
||||
# 1. Stop the agent controller first to prevent new LLM calls
|
||||
logger.debug('Stopping agent controller...')
|
||||
await controller.close()
|
||||
|
||||
# 2. Stop the EventStream to prevent new events from being processed
|
||||
logger.debug('Stopping EventStream...')
|
||||
event_stream.close()
|
||||
|
||||
# 3. Give time for in-flight operations to complete before closing runtime
|
||||
logger.debug('Waiting for in-flight operations to complete...')
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# 4. Close the runtime to avoid bash session interruption errors
|
||||
logger.debug('Closing runtime...')
|
||||
runtime.close()
|
||||
|
||||
# 5. Give a brief moment for final cleanup to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Error during graceful cleanup: {e}')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Exception in main loop: {e}')
|
||||
|
||||
# save session when we're about to close
|
||||
if config.file_store is not None and config.file_store != 'memory':
|
||||
end_state = controller.get_state()
|
||||
# NOTE: the saved state does not include delegates events
|
||||
end_state.save_to_session(
|
||||
event_stream.sid, event_stream.file_store, event_stream.user_id
|
||||
)
|
||||
|
||||
await controller.close(set_stop_state=False)
|
||||
|
||||
state = controller.get_state()
|
||||
|
||||
# save trajectories if applicable
|
||||
if config.save_trajectory_path is not None:
|
||||
# if save_trajectory_path is a folder, use session id as file name
|
||||
if os.path.isdir(config.save_trajectory_path):
|
||||
file_path = os.path.join(config.save_trajectory_path, sid + '.json')
|
||||
else:
|
||||
file_path = config.save_trajectory_path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
histories = controller.get_trajectory(config.save_screenshots_in_trajectory)
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(histories, f, indent=4)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def auto_continue_response(
|
||||
state: State,
|
||||
encapsulate_solution: bool = False,
|
||||
try_parse: Callable[[Action | None], str] | None = None,
|
||||
) -> str:
|
||||
"""Default function to generate user responses.
|
||||
Tell the agent to proceed without asking for more input, or finish the interaction.
|
||||
"""
|
||||
message = (
|
||||
'Please continue on whatever approach you think is suitable.\n'
|
||||
'If you think you have solved the task, please finish the interaction.\n'
|
||||
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN RESPONSE.\n'
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
|
||||
"""Load trajectory from given path, serialize it to a list of events, and return
|
||||
two things:
|
||||
1) A list of events except the first action
|
||||
2) First action (user message, a.k.a. initial task)
|
||||
"""
|
||||
try:
|
||||
path = Path(trajectory_path).resolve()
|
||||
|
||||
if not path.exists():
|
||||
raise ValueError(f'Trajectory file not found: {path}')
|
||||
|
||||
if not path.is_file():
|
||||
raise ValueError(f'Trajectory path is a directory, not a file: {path}')
|
||||
|
||||
with open(path, 'r', encoding='utf-8') as file:
|
||||
events = ReplayManager.get_replay_events(json.load(file))
|
||||
assert isinstance(events[0], MessageAction)
|
||||
return events[1:], events[0]
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
config: OpenHandsConfig = setup_config_from_args(args)
|
||||
|
||||
# Read task from file, CLI args, or stdin
|
||||
task_str = read_task(args, config.cli_multiline_input)
|
||||
|
||||
initial_user_action: Action = NullAction()
|
||||
if config.replay_trajectory_path:
|
||||
if task_str:
|
||||
raise ValueError(
|
||||
'User-specified task is not supported under trajectory replay mode'
|
||||
)
|
||||
else:
|
||||
if not task_str:
|
||||
raise ValueError('No task provided. Please specify a task through -t, -f.')
|
||||
|
||||
# Create actual initial user action
|
||||
initial_user_action = MessageAction(content=task_str)
|
||||
|
||||
# Set session name
|
||||
session_name = args.name
|
||||
sid = generate_sid(config, session_name)
|
||||
|
||||
asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=initial_user_action,
|
||||
sid=sid,
|
||||
fake_user_response_fn=None
|
||||
if args.no_auto_continue
|
||||
else auto_continue_response,
|
||||
)
|
||||
)
|
||||
@@ -8,93 +8,23 @@
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
OpenHandsConfig,
|
||||
)
|
||||
from openhands.core.config.config_utils import DEFAULT_WORKSPACE_MOUNT_PATH_IN_SANDBOX
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.event import Event
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderToken,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroagent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
|
||||
def create_runtime(
|
||||
config: OpenHandsConfig,
|
||||
llm_registry: LLMRegistry | None = None,
|
||||
sid: str | None = None,
|
||||
headless_mode: bool = True,
|
||||
agent: Agent | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
) -> Runtime:
|
||||
"""Create a runtime for the agent to run on.
|
||||
|
||||
Args:
|
||||
config: The app config.
|
||||
sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
|
||||
Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
|
||||
headless_mode: Whether the agent is run in headless mode. `create_runtime` is typically called within evaluation scripts,
|
||||
where we don't want to have the VSCode UI open, so it defaults to True.
|
||||
agent: (optional) The agent instance to use for configuring the runtime.
|
||||
|
||||
Returns:
|
||||
The created Runtime instance (not yet connected or initialized).
|
||||
"""
|
||||
# if sid is provided on the command line, use it as the name of the event stream
|
||||
# otherwise generate it on the basis of the configured jwt_secret
|
||||
# we can do this better, this is just so that the sid is retrieved when we want to restore the session
|
||||
session_id = sid or generate_sid(config)
|
||||
|
||||
# set up the event stream
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(session_id, file_store)
|
||||
|
||||
# agent class
|
||||
if agent:
|
||||
agent_cls = type(agent)
|
||||
else:
|
||||
agent_cls = Agent.get_cls(config.default_agent)
|
||||
|
||||
# runtime and tools
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
logger.debug(f'Initializing runtime: {runtime_cls.__name__}')
|
||||
runtime: Runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid=session_id,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
headless_mode=headless_mode,
|
||||
llm_registry=llm_registry or LLMRegistry(config),
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
|
||||
# Log the plugins that have been registered with the runtime for debugging purposes
|
||||
logger.debug(
|
||||
f'Runtime created with plugins: {[plugin.name for plugin in runtime.plugins]}'
|
||||
)
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
def get_provider_tokens():
|
||||
"""Retrieve provider tokens from environment variables and return them as a dictionary.
|
||||
|
||||
@@ -184,96 +114,6 @@ def initialize_repository_for_runtime(
|
||||
return repo_directory
|
||||
|
||||
|
||||
def create_memory(
|
||||
runtime: Runtime,
|
||||
event_stream: EventStream,
|
||||
sid: str,
|
||||
selected_repository: str | None = None,
|
||||
repo_directory: str | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
conversation_instructions: str | None = None,
|
||||
working_dir: str = DEFAULT_WORKSPACE_MOUNT_PATH_IN_SANDBOX,
|
||||
) -> Memory:
|
||||
"""Create a memory for the agent to use.
|
||||
|
||||
Args:
|
||||
runtime: The runtime to use.
|
||||
event_stream: The event stream it will subscribe to.
|
||||
sid: The session id.
|
||||
selected_repository: The repository to clone and start with, if any.
|
||||
repo_directory: The repository directory, if any.
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
conversation_instructions: Optional instructions that are passed to the agent
|
||||
"""
|
||||
memory = Memory(
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
|
||||
memory.set_conversation_instructions(conversation_instructions)
|
||||
|
||||
if runtime:
|
||||
# sets available hosts
|
||||
memory.set_runtime_info(runtime, {}, working_dir)
|
||||
|
||||
# loads microagents from repo/.openhands/microagents
|
||||
microagents: list[BaseMicroagent] = runtime.get_microagents_from_selected_repo(
|
||||
selected_repository
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(selected_repository, repo_directory)
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent:
|
||||
agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
|
||||
agent_config = config.get_agent_config(config.default_agent)
|
||||
# Pass the runtime information from the main config to the agent config
|
||||
agent_config.runtime = config.runtime
|
||||
config.get_llm_config_from_agent(config.default_agent)
|
||||
agent = agent_cls(config=agent_config, llm_registry=llm_registry)
|
||||
return agent
|
||||
|
||||
|
||||
def create_controller(
|
||||
agent: Agent,
|
||||
runtime: Runtime,
|
||||
config: OpenHandsConfig,
|
||||
conversation_stats: ConversationStats,
|
||||
headless_mode: bool = True,
|
||||
replay_events: list[Event] | None = None,
|
||||
) -> tuple[AgentController, State | None]:
|
||||
event_stream = runtime.event_stream
|
||||
initial_state = None
|
||||
try:
|
||||
logger.debug(
|
||||
f'Trying to restore agent state from session {event_stream.sid} if available'
|
||||
)
|
||||
initial_state = State.restore_from_session(
|
||||
event_stream.sid, event_stream.file_store
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f'Cannot restore agent state: {e}')
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
conversation_stats=conversation_stats,
|
||||
iteration_delta=config.max_iterations,
|
||||
budget_per_task_delta=config.max_budget_per_task,
|
||||
agent_to_llm_config=config.get_agent_to_llm_config_map(),
|
||||
event_stream=event_stream,
|
||||
initial_state=initial_state,
|
||||
headless_mode=headless_mode,
|
||||
confirmation_mode=config.security.confirmation_mode,
|
||||
replay_events=replay_events,
|
||||
)
|
||||
return (controller, initial_state)
|
||||
|
||||
|
||||
def generate_sid(config: OpenHandsConfig, session_name: str | None = None) -> str:
|
||||
"""Generate a session id based on the session name and the jwt secret.
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_store import EventStore
|
||||
|
||||
|
||||
class AsyncEventStoreWrapper:
|
||||
def __init__(self, event_store: EventStore, *args: Any, **kwargs: Any) -> None:
|
||||
self.event_store = event_store
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[Event]:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an async generator that yields events
|
||||
for event in self.event_store.search_events(*self.args, **self.kwargs):
|
||||
# Run the blocking search_events() in a thread pool
|
||||
def get_event(e: Event = event) -> Event:
|
||||
return e
|
||||
|
||||
yield await loop.run_in_executor(None, get_event)
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -11,7 +12,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.io import json
|
||||
from openhands.storage import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_dir,
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
|
||||
|
||||
def get_pairs_from_events(events: list[Event]) -> list[tuple[Action, Observation]]:
|
||||
"""Return the history as a list of tuples (action, observation).
|
||||
|
||||
This function is a compatibility function for evals reading and visualization working with old histories.
|
||||
"""
|
||||
tuples: list[tuple[Action, Observation]] = []
|
||||
action_map: dict[int, Action] = {}
|
||||
observation_map: dict[int, Observation] = {}
|
||||
|
||||
# runnable actions are set as cause of observations
|
||||
# (MessageAction, NullObservation) for source=USER
|
||||
# (MessageAction, NullObservation) for source=AGENT
|
||||
# (other_action?, NullObservation)
|
||||
# (NullAction, CmdOutputObservation) background CmdOutputObservations
|
||||
|
||||
for event in events:
|
||||
if event.id is None or event.id == -1:
|
||||
logger.debug(f'Event {event} has no ID')
|
||||
|
||||
if isinstance(event, Action):
|
||||
action_map[event.id] = event
|
||||
|
||||
if isinstance(event, Observation):
|
||||
if event.cause is None or event.cause == -1:
|
||||
logger.debug(f'Observation {event} has no cause')
|
||||
|
||||
if event.cause is None:
|
||||
# runnable actions are set as cause of observations
|
||||
# NullObservations have no cause
|
||||
continue
|
||||
|
||||
observation_map[event.cause] = event
|
||||
|
||||
for action_id, action in action_map.items():
|
||||
observation = observation_map.get(action_id)
|
||||
if observation:
|
||||
# observation with a cause
|
||||
tuples.append((action, observation))
|
||||
else:
|
||||
tuples.append((action, NullObservation('')))
|
||||
|
||||
for cause_id, observation in observation_map.items():
|
||||
if cause_id not in action_map:
|
||||
if isinstance(observation, NullObservation):
|
||||
continue
|
||||
if not isinstance(observation, CmdOutputObservation):
|
||||
logger.debug(f'Observation {observation} has no cause')
|
||||
tuples.append((NullAction(), observation))
|
||||
|
||||
return tuples.copy()
|
||||
@@ -1,9 +1,7 @@
|
||||
"""Feature operations for Azure DevOps integration (microagents, suggested tasks, user)."""
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.azure_devops.service.base import AzureDevOpsMixinBase
|
||||
from openhands.integrations.service_types import (
|
||||
MicroagentContentResponse,
|
||||
ProviderType,
|
||||
RequestMethod,
|
||||
SuggestedTask,
|
||||
@@ -139,85 +137,3 @@ class AzureDevOpsFeaturesMixin(AzureDevOpsMixinBase):
|
||||
continue
|
||||
|
||||
return tasks
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file in Azure DevOps."""
|
||||
org, project, repo = self._parse_repository(repository)
|
||||
# URL-encode components to handle spaces and special characters
|
||||
org_enc = self._encode_url_component(org)
|
||||
project_enc = self._encode_url_component(project)
|
||||
repo_enc = self._encode_url_component(repo)
|
||||
return f'{self.base_url}/{org_enc}/{project_enc}/_apis/git/repositories/{repo_enc}/items?path=/.cursorrules&api-version=7.1'
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory in Azure DevOps.
|
||||
|
||||
Note: For org-level microagents (e.g., 'org/.openhands'), Azure DevOps doesn't support
|
||||
this concept, so we raise ValueError to let the caller fall back to other providers.
|
||||
"""
|
||||
parts = repository.split('/')
|
||||
if len(parts) < 3:
|
||||
# Azure DevOps doesn't support org-level configs, only full repo paths
|
||||
raise ValueError(
|
||||
f'Invalid repository format: {repository}. Expected format: organization/project/repo'
|
||||
)
|
||||
org, project, repo = parts[0], parts[1], parts[2]
|
||||
# URL-encode components to handle spaces and special characters
|
||||
org_enc = self._encode_url_component(org)
|
||||
project_enc = self._encode_url_component(project)
|
||||
repo_enc = self._encode_url_component(repo)
|
||||
return f'{self.base_url}/{org_enc}/{project_enc}/_apis/git/repositories/{repo_enc}/items?path=/{microagents_path}&recursionLevel=OneLevel&api-version=7.1'
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
"""Get parameters for the microagents directory request. Return None if no parameters needed."""
|
||||
return None
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file in Azure DevOps."""
|
||||
return (
|
||||
not item.get('isFolder', False)
|
||||
and item.get('path', '').endswith('.md')
|
||||
and not item.get('path', '').endswith('README.md')
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item in Azure DevOps."""
|
||||
path = item.get('path', '')
|
||||
return path.split('/')[-1] if path else ''
|
||||
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item in Azure DevOps."""
|
||||
return item.get('path', '').lstrip('/')
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Get content of a specific microagent file.
|
||||
|
||||
Args:
|
||||
repository: Repository name in Azure DevOps format 'org/project/repo'
|
||||
file_path: Path to the microagent file
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
"""
|
||||
org, project, repo = self._parse_repository(repository)
|
||||
# URL-encode components to handle spaces and special characters
|
||||
org_enc = self._encode_url_component(org)
|
||||
project_enc = self._encode_url_component(project)
|
||||
repo_enc = self._encode_url_component(repo)
|
||||
url = f'{self.base_url}/{org_enc}/{project_enc}/_apis/git/repositories/{repo_enc}/items?path={file_path}&api-version=7.1'
|
||||
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
content = (
|
||||
response if isinstance(response, str) else response.get('content', '')
|
||||
)
|
||||
|
||||
# Parse the content using the base class method
|
||||
return self._parse_microagent_content(content, file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to fetch microagent content from {file_path}: {e}')
|
||||
raise
|
||||
|
||||
@@ -4,7 +4,6 @@ from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.bitbucket.service import (
|
||||
BitBucketBranchesMixin,
|
||||
BitBucketFeaturesMixin,
|
||||
BitBucketPRsMixin,
|
||||
BitBucketReposMixin,
|
||||
)
|
||||
@@ -20,7 +19,6 @@ class BitBucketService(
|
||||
BitBucketReposMixin,
|
||||
BitBucketBranchesMixin,
|
||||
BitBucketPRsMixin,
|
||||
BitBucketFeaturesMixin,
|
||||
GitService,
|
||||
InstallationsService,
|
||||
):
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from .base import BitBucketMixinBase
|
||||
from .branches import BitBucketBranchesMixin
|
||||
from .features import BitBucketFeaturesMixin
|
||||
from .prs import BitBucketPRsMixin
|
||||
from .repos import BitBucketReposMixin
|
||||
|
||||
__all__ = [
|
||||
'BitBucketMixinBase',
|
||||
'BitBucketBranchesMixin',
|
||||
'BitBucketFeaturesMixin',
|
||||
'BitBucketPRsMixin',
|
||||
'BitBucketReposMixin',
|
||||
]
|
||||
|
||||
@@ -12,7 +12,6 @@ from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
ResourceNotFoundError,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
@@ -236,47 +235,3 @@ class BitBucketMixinBase(BaseGitService, HTTPClient):
|
||||
url = f'{self.BASE_URL}/repositories/{repository}'
|
||||
data, _ = await self._make_request(url)
|
||||
return self._parse_repository(data)
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file."""
|
||||
# Get repository details to get the main branch
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
if not repo_details.main_branch:
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
return f'{self.BASE_URL}/repositories/{repository}/src/{repo_details.main_branch}/.cursorrules'
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory."""
|
||||
# Get repository details to get the main branch
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
if not repo_details.main_branch:
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
return f'{self.BASE_URL}/repositories/{repository}/src/{repo_details.main_branch}/{microagents_path}'
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
"""Get parameters for the microagents directory request. Return None if no parameters needed."""
|
||||
return None
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file."""
|
||||
return (
|
||||
item['type'] == 'commit_file'
|
||||
and item['path'].endswith('.md')
|
||||
and not item['path'].endswith('README.md')
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item."""
|
||||
return item['path'].split('/')[-1]
|
||||
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item."""
|
||||
return item['path']
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.bitbucket.service.base import BitBucketMixinBase
|
||||
from openhands.integrations.service_types import ResourceNotFoundError
|
||||
from openhands.microagent.types import MicroagentContentResponse
|
||||
|
||||
|
||||
class BitBucketFeaturesMixin(BitBucketMixinBase):
|
||||
"""
|
||||
Mixin for BitBucket feature operations (microagents, cursor rules, etc.)
|
||||
"""
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Fetch individual file content from Bitbucket repository.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format 'workspace/repo_slug'
|
||||
file_path: Path to the file within the repository
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be fetched or doesn't exist
|
||||
"""
|
||||
# Step 1: Get repository details using existing method
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
|
||||
if not repo_details.main_branch:
|
||||
logger.warning(
|
||||
f'No main branch found in repository info for {repository}. '
|
||||
f'Repository response: mainbranch field missing'
|
||||
)
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
|
||||
# Step 2: Get file content using the main branch
|
||||
file_url = f'{self.BASE_URL}/repositories/{repository}/src/{repo_details.main_branch}/{file_path}'
|
||||
response, _ = await self._make_request(file_url)
|
||||
|
||||
# Parse the content to extract triggers from frontmatter
|
||||
return self._parse_microagent_content(response, file_path)
|
||||
@@ -4,7 +4,6 @@ from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.bitbucket_data_center.service import (
|
||||
BitbucketDCBranchesMixin,
|
||||
BitbucketDCFeaturesMixin,
|
||||
BitbucketDCPRsMixin,
|
||||
BitbucketDCReposMixin,
|
||||
BitbucketDCResolverMixin,
|
||||
@@ -20,7 +19,6 @@ from openhands.utils.import_utils import get_impl
|
||||
class BitbucketDCService(
|
||||
BitbucketDCResolverMixin,
|
||||
BitbucketDCBranchesMixin,
|
||||
BitbucketDCFeaturesMixin,
|
||||
BitbucketDCPRsMixin,
|
||||
BitbucketDCReposMixin,
|
||||
GitService,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .base import BitbucketDCMixinBase
|
||||
from .branches import BitbucketDCBranchesMixin
|
||||
from .features import BitbucketDCFeaturesMixin
|
||||
from .prs import BitbucketDCPRsMixin
|
||||
from .repos import BitbucketDCReposMixin
|
||||
from .resolver import BitbucketDCResolverMixin
|
||||
@@ -8,7 +7,6 @@ from .resolver import BitbucketDCResolverMixin
|
||||
__all__ = [
|
||||
'BitbucketDCMixinBase',
|
||||
'BitbucketDCBranchesMixin',
|
||||
'BitbucketDCFeaturesMixin',
|
||||
'BitbucketDCPRsMixin',
|
||||
'BitbucketDCReposMixin',
|
||||
'BitbucketDCResolverMixin',
|
||||
|
||||
@@ -13,7 +13,6 @@ from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
ResourceNotFoundError,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
@@ -282,58 +281,3 @@ class BitbucketDCMixinBase(BaseGitService, HTTPClient):
|
||||
url = self._repo_api_base(owner, repo)
|
||||
data, _ = await self._make_request(url)
|
||||
return await self._parse_repository(data, fetch_default_branch=True)
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file."""
|
||||
# Get repository details to get the main branch
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
if not repo_details.main_branch:
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
owner, repo = self._extract_owner_and_repo(repository)
|
||||
return (
|
||||
f'{self.BASE_URL}/projects/{owner}/repos/{repo}/browse/.cursorrules'
|
||||
f'?at=refs/heads/{repo_details.main_branch}'
|
||||
)
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory."""
|
||||
# Get repository details to get the main branch
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
if not repo_details.main_branch:
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
|
||||
owner, repo = self._extract_owner_and_repo(repository)
|
||||
return (
|
||||
f'{self.BASE_URL}/projects/{owner}/repos/{repo}/browse/{microagents_path}'
|
||||
f'?at=refs/heads/{repo_details.main_branch}'
|
||||
)
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
"""Get parameters for the microagents directory request. Return None if no parameters needed."""
|
||||
return None
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file."""
|
||||
file_name = item.get('path', {}).get('name', '')
|
||||
return (
|
||||
item.get('type') == 'FILE'
|
||||
and file_name.endswith('.md')
|
||||
and file_name != 'README.md'
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item."""
|
||||
return item.get('path', {}).get('name', '')
|
||||
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item."""
|
||||
file_name = self._get_file_name_from_item(item)
|
||||
return f'{microagents_path}/{file_name}'
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.bitbucket_data_center.service.base import (
|
||||
BitbucketDCMixinBase,
|
||||
)
|
||||
from openhands.integrations.service_types import ResourceNotFoundError
|
||||
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
||||
|
||||
|
||||
class BitbucketDCFeaturesMixin(BitbucketDCMixinBase):
|
||||
"""
|
||||
Mixin for BitBucket data center feature operations (microagents, cursor rules, etc.)
|
||||
"""
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Fetch individual file content from Bitbucket data center repository.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format 'project/repo_slug'
|
||||
file_path: Path to the file within the repository
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be fetched or doesn't exist
|
||||
"""
|
||||
# Step 1: Get repository details using existing method
|
||||
repo_details = await self.get_repository_details_from_repo_name(repository)
|
||||
|
||||
if not repo_details.main_branch:
|
||||
logger.warning(
|
||||
f'No main branch found in repository info for {repository}. '
|
||||
f'Repository response: mainbranch field missing'
|
||||
)
|
||||
raise ResourceNotFoundError(
|
||||
f'Main branch not found for repository {repository}. '
|
||||
f'This repository may be empty or have no default branch configured.'
|
||||
)
|
||||
|
||||
# Step 2: Get file content using the main branch
|
||||
owner, repo = self._extract_owner_and_repo(repository)
|
||||
repo_base = self._repo_api_base(owner, repo)
|
||||
|
||||
file_url = f'{repo_base}/browse/{file_path}'
|
||||
params = {'at': f'refs/heads/{repo_details.main_branch}'}
|
||||
response, _ = await self._make_request(file_url, params=params)
|
||||
if isinstance(response, dict):
|
||||
lines = response.get('lines')
|
||||
if isinstance(lines, list):
|
||||
content = '\n'.join(
|
||||
line.get('text', '') for line in lines if isinstance(line, dict)
|
||||
)
|
||||
else:
|
||||
content = response.get('content', '')
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
# Parse the content to extract triggers from frontmatter
|
||||
return self._parse_microagent_content(content, file_path)
|
||||
|
||||
async def _process_microagents_directory(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> list[MicroagentResponse]:
|
||||
microagents = []
|
||||
try:
|
||||
directory_url = await self._get_microagents_directory_url(
|
||||
repository, microagents_path
|
||||
)
|
||||
directory_params = self._get_microagents_directory_params(microagents_path)
|
||||
response, _ = await self._make_request(directory_url, directory_params)
|
||||
|
||||
# Bitbucket DC browse endpoint nests items under response['children']['values']
|
||||
items = response.get('children', {}).get('values', [])
|
||||
|
||||
for item in items:
|
||||
if self._is_valid_microagent_file(item):
|
||||
try:
|
||||
file_name = self._get_file_name_from_item(item)
|
||||
file_path = self._get_file_path_from_item(
|
||||
item, microagents_path
|
||||
)
|
||||
microagents.append(
|
||||
self._create_microagent_response(file_name, file_path)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Error processing microagent {item}: {str(e)}')
|
||||
except ResourceNotFoundError:
|
||||
logger.info(
|
||||
f'No microagents directory found in {repository} at {microagents_path}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Error fetching microagents directory: {str(e)}')
|
||||
|
||||
return microagents
|
||||
@@ -1,123 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.forgejo.service.base import ForgejoMixinBase
|
||||
from openhands.integrations.service_types import (
|
||||
MicroagentContentResponse,
|
||||
MicroagentResponse,
|
||||
ProviderType,
|
||||
ResourceNotFoundError,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.integrations.service_types import SuggestedTask
|
||||
|
||||
|
||||
class ForgejoFeaturesMixin(ForgejoMixinBase):
|
||||
"""Microagent and feature helpers for Forgejo."""
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
owner, repo = self._split_repo(repository)
|
||||
return self._build_repo_api_url(owner, repo, 'contents', '.cursorrules')
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
owner, repo = self._split_repo(repository)
|
||||
normalized_path = microagents_path.strip('/')
|
||||
return self._build_repo_api_url(owner, repo, 'contents', normalized_path)
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
return None
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict[str, Any] | None) -> bool:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
if item.get('type') != 'file':
|
||||
return False
|
||||
name = item.get('name', '')
|
||||
return isinstance(name, str) and (
|
||||
name.endswith('.md') or name.endswith('.cursorrules')
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict[str, Any] | None) -> str:
|
||||
if not isinstance(item, dict):
|
||||
return ''
|
||||
name = item.get('name')
|
||||
return name if isinstance(name, str) else ''
|
||||
|
||||
def _get_file_path_from_item(
|
||||
self, item: dict[str, Any] | None, microagents_path: str
|
||||
) -> str:
|
||||
file_name = self._get_file_name_from_item(item)
|
||||
if not microagents_path:
|
||||
return file_name
|
||||
return f'{microagents_path.strip("/")}/{file_name}'
|
||||
|
||||
async def get_microagents(self, repository: str) -> list[MicroagentResponse]: # type: ignore[override]
|
||||
microagents_path = self._determine_microagents_path(repository)
|
||||
microagents: list[MicroagentResponse] = []
|
||||
|
||||
try:
|
||||
directory_url = await self._get_microagents_directory_url(
|
||||
repository, microagents_path
|
||||
)
|
||||
items, _ = await self._make_request(directory_url)
|
||||
except ResourceNotFoundError:
|
||||
items = []
|
||||
except Exception as exc:
|
||||
# Fail gracefully if the directory cannot be inspected
|
||||
self._log_microagent_warning(repository, str(exc))
|
||||
items = []
|
||||
|
||||
if isinstance(items, list):
|
||||
for item in items:
|
||||
if self._is_valid_microagent_file(item):
|
||||
file_name = self._get_file_name_from_item(item)
|
||||
file_path = self._get_file_path_from_item(item, microagents_path)
|
||||
microagents.append(
|
||||
self._create_microagent_response(file_name, file_path)
|
||||
)
|
||||
|
||||
cursorrules = await self._check_cursorrules_file(repository)
|
||||
if cursorrules:
|
||||
microagents.append(cursorrules)
|
||||
|
||||
return microagents
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse: # type: ignore[override]
|
||||
owner, repo = self._split_repo(repository)
|
||||
normalized_path = file_path.lstrip('/')
|
||||
url = self._build_repo_api_url(owner, repo, 'contents', normalized_path)
|
||||
|
||||
response, _ = await self._make_request(url)
|
||||
content = response.get('content') or ''
|
||||
encoding = (response.get('encoding') or 'base64').lower()
|
||||
|
||||
if encoding == 'base64':
|
||||
try:
|
||||
decoded = base64.b64decode(content).decode('utf-8')
|
||||
except Exception:
|
||||
decoded = ''
|
||||
else:
|
||||
decoded = content
|
||||
|
||||
try:
|
||||
return self._parse_microagent_content(decoded, file_path)
|
||||
except Exception:
|
||||
return MicroagentContentResponse(
|
||||
content=decoded,
|
||||
path=file_path,
|
||||
triggers=[],
|
||||
git_provider=ProviderType.FORGEJO.value,
|
||||
)
|
||||
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]: # type: ignore[override]
|
||||
# Suggested tasks are not yet implemented for Forgejo.
|
||||
return []
|
||||
|
||||
def _log_microagent_warning(self, repository: str, message: str) -> None:
|
||||
logger.debug(f'Forgejo microagent scan warning for {repository}: {message}')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import base64
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.queries import (
|
||||
suggested_task_issue_graphql_query,
|
||||
@@ -7,7 +5,6 @@ from openhands.integrations.github.queries import (
|
||||
)
|
||||
from openhands.integrations.github.service.base import GitHubMixinBase
|
||||
from openhands.integrations.service_types import (
|
||||
MicroagentContentResponse,
|
||||
ProviderType,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
@@ -118,60 +115,3 @@ class GitHubFeaturesMixin(GitHubMixinBase):
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
"""
|
||||
Methods specifically for microagent management page
|
||||
"""
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file."""
|
||||
return f'{self.BASE_URL}/repos/{repository}/contents/.cursorrules'
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory."""
|
||||
return f'{self.BASE_URL}/repos/{repository}/contents/{microagents_path}'
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file."""
|
||||
return (
|
||||
item['type'] == 'file'
|
||||
and item['name'].endswith('.md')
|
||||
and item['name'] != 'README.md'
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item."""
|
||||
return item['name']
|
||||
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item."""
|
||||
return f'{microagents_path}/{item["name"]}'
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
"""Get parameters for the microagents directory request. Return None if no parameters needed."""
|
||||
return None
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Fetch individual file content from GitHub repository.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format 'owner/repo'
|
||||
file_path: Path to the file within the repository
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be fetched or doesn't exist
|
||||
"""
|
||||
file_url = f'{self.BASE_URL}/repos/{repository}/contents/{file_path}'
|
||||
|
||||
file_data, _ = await self._make_request(file_url)
|
||||
file_content = base64.b64decode(file_data['content']).decode('utf-8')
|
||||
|
||||
# Parse the content to extract triggers from frontmatter
|
||||
return self._parse_microagent_content(file_content, file_path)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from openhands.integrations.gitlab.service.base import GitLabMixinBase
|
||||
from openhands.integrations.service_types import (
|
||||
MicroagentContentResponse,
|
||||
ProviderType,
|
||||
RequestMethod,
|
||||
SuggestedTask,
|
||||
@@ -13,40 +12,6 @@ class GitLabFeaturesMixin(GitLabMixinBase):
|
||||
Methods used for custom features in UI driven via GitLab integration
|
||||
"""
|
||||
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file."""
|
||||
project_id = self._extract_project_id(repository)
|
||||
return (
|
||||
f'{self.BASE_URL}/projects/{project_id}/repository/files/.cursorrules/raw'
|
||||
)
|
||||
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory."""
|
||||
project_id = self._extract_project_id(repository)
|
||||
return f'{self.BASE_URL}/projects/{project_id}/repository/tree'
|
||||
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict:
|
||||
"""Get parameters for the microagents directory request."""
|
||||
return {'path': microagents_path, 'recursive': 'true'}
|
||||
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file."""
|
||||
return (
|
||||
item['type'] == 'blob'
|
||||
and item['name'].endswith('.md')
|
||||
and item['name'] != 'README.md'
|
||||
)
|
||||
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item."""
|
||||
return item['name']
|
||||
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item."""
|
||||
return item['path']
|
||||
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]:
|
||||
"""Get suggested tasks for the authenticated user across all repositories.
|
||||
|
||||
@@ -178,30 +143,3 @@ class GitLabFeaturesMixin(GitLabMixinBase):
|
||||
return tasks
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Fetch individual file content from GitLab repository.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format 'owner/repo' or 'domain/owner/repo'
|
||||
file_path: Path to the file within the repository
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be fetched or doesn't exist
|
||||
"""
|
||||
# Extract project_id from repository name
|
||||
project_id = self._extract_project_id(repository)
|
||||
|
||||
encoded_file_path = file_path.replace('/', '%2F')
|
||||
base_url = f'{self.BASE_URL}/projects/{project_id}'
|
||||
file_url = f'{base_url}/repository/files/{encoded_file_path}/raw'
|
||||
|
||||
response, _ = await self._make_request(file_url)
|
||||
|
||||
# Parse the content to extract triggers from frontmatter
|
||||
return self._parse_microagent_content(response, file_path)
|
||||
|
||||
@@ -33,17 +33,14 @@ from openhands.integrations.service_types import (
|
||||
Branch,
|
||||
GitService,
|
||||
InstallationsService,
|
||||
MicroagentParseError,
|
||||
PaginatedBranchesResponse,
|
||||
ProviderTimeoutError,
|
||||
ProviderType,
|
||||
Repository,
|
||||
ResourceNotFoundError,
|
||||
SuggestedTask,
|
||||
TokenResponse,
|
||||
User,
|
||||
)
|
||||
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
@@ -599,104 +596,6 @@ class ProviderHandler:
|
||||
total_count=0,
|
||||
)
|
||||
|
||||
async def get_microagents(self, repository: str) -> list[MicroagentResponse]:
|
||||
"""Get microagents from a repository using the appropriate service.
|
||||
|
||||
Args:
|
||||
repository: Repository name in the format 'owner/repo'
|
||||
|
||||
Returns:
|
||||
List of microagents found in the repository
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
# Try all available providers in order
|
||||
errors = []
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self.get_service(provider)
|
||||
result = await service.get_microagents(repository)
|
||||
# Only return early if we got a non-empty result
|
||||
if result:
|
||||
return result
|
||||
# If we got an empty array, continue checking other providers
|
||||
logger.debug(
|
||||
f'No microagents found on {provider} for {repository}, trying other providers'
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(f'{provider.value}: {str(e)}')
|
||||
logger.warning(
|
||||
f'Error fetching microagents from {provider} for {repository}: {e}'
|
||||
)
|
||||
|
||||
# If all providers failed or returned empty results, return empty array
|
||||
if errors:
|
||||
logger.error(
|
||||
f'Failed to fetch microagents for {repository} with all available providers. Errors: {"; ".join(errors)}'
|
||||
)
|
||||
raise AuthenticationError(f'Unable to fetch microagents for {repository}')
|
||||
|
||||
# All providers returned empty arrays
|
||||
return []
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Get content of a specific microagent file from a repository.
|
||||
|
||||
Args:
|
||||
repository: Repository name in the format 'owner/repo'
|
||||
file_path: Path to the microagent file within the repository
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
# Try all available providers in order
|
||||
errors = []
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self.get_service(provider)
|
||||
result = await service.get_microagent_content(repository, file_path)
|
||||
# If we got content, return it immediately
|
||||
if result:
|
||||
return result
|
||||
# If we got empty content, continue checking other providers
|
||||
logger.debug(
|
||||
f'No content found on {provider} for {repository}/{file_path}, trying other providers'
|
||||
)
|
||||
except ResourceNotFoundError:
|
||||
logger.debug(
|
||||
f'File not found on {provider} for {repository}/{file_path}, trying other providers'
|
||||
)
|
||||
continue
|
||||
except MicroagentParseError as e:
|
||||
# Parsing errors are specific to the provider, add to errors list
|
||||
errors.append(f'{provider.value}: {str(e)}')
|
||||
logger.warning(
|
||||
f'Error parsing microagent content from {provider} for {repository}: {e}'
|
||||
)
|
||||
except Exception as e:
|
||||
# For other errors (auth, rate limit, etc.), add to errors list
|
||||
errors.append(f'{provider.value}: {str(e)}')
|
||||
logger.warning(
|
||||
f'Error fetching microagent content from {provider} for {repository}: {e}'
|
||||
)
|
||||
|
||||
# If all providers failed or returned empty results, raise an error
|
||||
if errors:
|
||||
logger.error(
|
||||
f'Failed to fetch microagent content for {repository} with all available providers. Errors: {"; ".join(errors)}'
|
||||
)
|
||||
|
||||
# All providers returned empty content or file not found
|
||||
raise AuthenticationError(
|
||||
f'Microagent file {file_path} not found in {repository}'
|
||||
)
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repo_name: str, is_optional: bool = False
|
||||
) -> str:
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.microagent.microagent import BaseMicroagent
|
||||
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
@@ -33,7 +29,6 @@ class TaskType(str, Enum):
|
||||
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
|
||||
OPEN_ISSUE = 'OPEN_ISSUE'
|
||||
OPEN_PR = 'OPEN_PR'
|
||||
CREATE_MICROAGENT = 'CREATE_MICROAGENT'
|
||||
|
||||
|
||||
class OwnerType(str, Enum):
|
||||
@@ -120,12 +115,6 @@ class SuggestedTask(BaseModel):
|
||||
return template.render(issue_number=issue_number, repo=repo, **terms)
|
||||
|
||||
|
||||
class CreateMicroagent(BaseModel):
|
||||
repo: str
|
||||
git_provider: ProviderType | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class UserGitInfo(BaseModel):
|
||||
id: str
|
||||
login: str
|
||||
@@ -207,12 +196,6 @@ class ResourceNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class MicroagentParseError(ValueError):
|
||||
"""Raised when there is an error parsing a microagent file."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RequestMethod(Enum):
|
||||
POST = 'post'
|
||||
GET = 'get'
|
||||
@@ -232,216 +215,6 @@ class BaseGitService(ABC):
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
) -> tuple[Any, dict]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def _get_cursorrules_url(self, repository: str) -> str:
|
||||
"""Get the URL for checking .cursorrules file."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _get_microagents_directory_url(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> str:
|
||||
"""Get the URL for checking microagents directory."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_microagents_directory_params(self, microagents_path: str) -> dict | None:
|
||||
"""Get parameters for the microagents directory request. Return None if no parameters needed."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _is_valid_microagent_file(self, item: dict) -> bool:
|
||||
"""Check if an item represents a valid microagent file."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_file_name_from_item(self, item: dict) -> str:
|
||||
"""Extract file name from directory item."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_file_path_from_item(self, item: dict, microagents_path: str) -> str:
|
||||
"""Extract file path from directory item."""
|
||||
...
|
||||
|
||||
def _determine_microagents_path(self, repository_name: str) -> str:
|
||||
"""Determine the microagents directory path based on repository name."""
|
||||
actual_repo_name = repository_name.split('/')[-1]
|
||||
|
||||
# Check for special repository names that use a different structure
|
||||
if actual_repo_name == '.openhands' or actual_repo_name == 'openhands-config':
|
||||
# For repository name ".openhands", scan "microagents" folder
|
||||
return 'microagents'
|
||||
else:
|
||||
# Default behavior: look for .openhands/microagents directory
|
||||
return '.openhands/microagents'
|
||||
|
||||
def _create_microagent_response(
|
||||
self, file_name: str, path: str
|
||||
) -> MicroagentResponse:
|
||||
"""Create a microagent response from basic file information."""
|
||||
# Extract name without extension
|
||||
name = file_name.replace('.md', '').replace('.cursorrules', 'cursorrules')
|
||||
|
||||
return MicroagentResponse(
|
||||
name=name,
|
||||
path=path,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def _parse_microagent_content(
|
||||
self, content: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Parse microagent content and extract triggers using BaseMicroagent.load.
|
||||
|
||||
Args:
|
||||
content: Raw microagent file content
|
||||
file_path: Path to the file (used for microagent loading)
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
|
||||
Raises:
|
||||
MicroagentParseError: If the microagent file cannot be parsed
|
||||
"""
|
||||
try:
|
||||
# Use BaseMicroagent.load to properly parse the content
|
||||
# Create a temporary path object for the file
|
||||
temp_path = Path(file_path)
|
||||
|
||||
# Load the microagent using the existing infrastructure
|
||||
microagent = BaseMicroagent.load(path=temp_path, file_content=content)
|
||||
|
||||
# Extract triggers from the microagent's metadata
|
||||
triggers = microagent.metadata.triggers
|
||||
|
||||
# Return the MicroagentContentResponse
|
||||
return MicroagentContentResponse(
|
||||
content=microagent.content,
|
||||
path=file_path,
|
||||
triggers=triggers,
|
||||
git_provider=self.provider,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error parsing microagent content for {file_path}: {str(e)}')
|
||||
raise MicroagentParseError(
|
||||
f'Failed to parse microagent file {file_path}: {str(e)}'
|
||||
)
|
||||
|
||||
async def _fetch_cursorrules_content(self, repository: str) -> Any | None:
|
||||
"""Fetch .cursorrules file content from the repository via API.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format specific to the provider
|
||||
|
||||
Returns:
|
||||
Raw API response content if .cursorrules file exists, None otherwise
|
||||
"""
|
||||
cursorrules_url = await self._get_cursorrules_url(repository)
|
||||
cursorrules_response, _ = await self._make_request(cursorrules_url)
|
||||
return cursorrules_response
|
||||
|
||||
async def _check_cursorrules_file(
|
||||
self, repository: str
|
||||
) -> MicroagentResponse | None:
|
||||
"""Check for .cursorrules file in the repository and return microagent response if found.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format specific to the provider
|
||||
|
||||
Returns:
|
||||
MicroagentResponse for .cursorrules file if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
cursorrules_content = await self._fetch_cursorrules_content(repository)
|
||||
if cursorrules_content:
|
||||
return self._create_microagent_response('.cursorrules', '.cursorrules')
|
||||
except ResourceNotFoundError:
|
||||
logger.debug(f'No .cursorrules file found in {repository}')
|
||||
except Exception as e:
|
||||
logger.warning(f'Error checking .cursorrules file in {repository}: {e}')
|
||||
|
||||
return None
|
||||
|
||||
async def _process_microagents_directory(
|
||||
self, repository: str, microagents_path: str
|
||||
) -> list[MicroagentResponse]:
|
||||
"""Process microagents directory and return list of microagent responses.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format specific to the provider
|
||||
microagents_path: Path to the microagents directory
|
||||
|
||||
Returns:
|
||||
List of MicroagentResponse objects found in the directory
|
||||
"""
|
||||
microagents = []
|
||||
|
||||
try:
|
||||
directory_url = await self._get_microagents_directory_url(
|
||||
repository, microagents_path
|
||||
)
|
||||
directory_params = self._get_microagents_directory_params(microagents_path)
|
||||
response, _ = await self._make_request(directory_url, directory_params)
|
||||
|
||||
# Handle different response structures
|
||||
items = response
|
||||
if isinstance(response, dict) and 'values' in response:
|
||||
# Bitbucket format
|
||||
items = response['values']
|
||||
elif isinstance(response, dict) and 'nodes' in response:
|
||||
# GraphQL format (if used)
|
||||
items = response['nodes']
|
||||
|
||||
for item in items:
|
||||
if self._is_valid_microagent_file(item):
|
||||
try:
|
||||
file_name = self._get_file_name_from_item(item)
|
||||
file_path = self._get_file_path_from_item(
|
||||
item, microagents_path
|
||||
)
|
||||
microagents.append(
|
||||
self._create_microagent_response(file_name, file_path)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Error processing microagent {item.get("name", "unknown")}: {str(e)}'
|
||||
)
|
||||
except ResourceNotFoundError:
|
||||
logger.info(
|
||||
f'No microagents directory found in {repository} at {microagents_path}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Error fetching microagents directory: {str(e)}')
|
||||
|
||||
return microagents
|
||||
|
||||
async def get_microagents(self, repository: str) -> list[MicroagentResponse]:
|
||||
"""Generic implementation of get_microagents that works across all providers.
|
||||
|
||||
Args:
|
||||
repository: Repository name in format specific to the provider
|
||||
|
||||
Returns:
|
||||
List of microagents found in the repository (without content for performance)
|
||||
"""
|
||||
microagents_path = self._determine_microagents_path(repository)
|
||||
microagents = []
|
||||
|
||||
# Step 1: Check for .cursorrules file
|
||||
cursorrules_microagent = await self._check_cursorrules_file(repository)
|
||||
if cursorrules_microagent:
|
||||
microagents.append(cursorrules_microagent)
|
||||
|
||||
# Step 2: Check for microagents directory and process .md files
|
||||
directory_microagents = await self._process_microagents_directory(
|
||||
repository, microagents_path
|
||||
)
|
||||
microagents.extend(directory_microagents)
|
||||
|
||||
return microagents
|
||||
|
||||
def _truncate_comment(
|
||||
self, comment_body: str, max_comment_length: int = 500
|
||||
) -> str:
|
||||
@@ -531,20 +304,6 @@ class GitService(Protocol):
|
||||
) -> list[Branch]:
|
||||
"""Search for branches within a repository"""
|
||||
|
||||
async def get_microagents(self, repository: str) -> list[MicroagentResponse]:
|
||||
"""Get microagents from a repository"""
|
||||
...
|
||||
|
||||
async def get_microagent_content(
|
||||
self, repository: str, file_path: str
|
||||
) -> MicroagentContentResponse:
|
||||
"""Get content of a specific microagent file
|
||||
|
||||
Returns:
|
||||
MicroagentContentResponse with parsed content and triggers
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_pr_details(self, repository: str, pr_number: int) -> dict:
|
||||
"""Get detailed information about a specific pull request/merge request
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from openhands.io.io import read_input, read_task, read_task_from_file
|
||||
from openhands.io.json import dumps, loads
|
||||
|
||||
__all__ = [
|
||||
'read_input',
|
||||
'read_task_from_file',
|
||||
'read_task',
|
||||
'dumps',
|
||||
'loads',
|
||||
]
|
||||
@@ -1,37 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
def read_input(cli_multiline_input: bool = False) -> str:
|
||||
"""Read input from user based on config settings."""
|
||||
if cli_multiline_input:
|
||||
print('Enter your message (enter "/exit" on a new line to finish):')
|
||||
lines = []
|
||||
while True:
|
||||
line = input('>> ').rstrip()
|
||||
if line == '/exit': # finish input
|
||||
break
|
||||
lines.append(line)
|
||||
return '\n'.join(lines)
|
||||
else:
|
||||
return input('>> ').rstrip()
|
||||
|
||||
|
||||
def read_task_from_file(file_path: str) -> str:
|
||||
"""Read task from the specified file."""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
|
||||
"""Read the task from the CLI args, file, or stdin."""
|
||||
# Determine the task
|
||||
task_str = ''
|
||||
if args.file:
|
||||
task_str = read_task_from_file(args.file)
|
||||
elif args.task:
|
||||
task_str = args.task
|
||||
elif not sys.stdin.isatty():
|
||||
task_str = read_input(cli_multiline_input)
|
||||
|
||||
return task_str
|
||||
@@ -1,75 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from json_repair import repair_json
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from openhands.core.exceptions import LLMResponseError
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import CmdOutputMetadata
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
class OpenHandsJSONEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that handles datetime and event objects"""
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, Event):
|
||||
return event_to_dict(obj)
|
||||
if isinstance(obj, Metrics):
|
||||
return obj.get()
|
||||
if isinstance(obj, ModelResponse):
|
||||
return obj.model_dump()
|
||||
if isinstance(obj, CmdOutputMetadata):
|
||||
return obj.model_dump()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
# Create a single reusable encoder instance
|
||||
_json_encoder = OpenHandsJSONEncoder()
|
||||
|
||||
|
||||
def dumps(obj, **kwargs) -> str:
|
||||
"""Serialize an object to str format"""
|
||||
if not kwargs:
|
||||
return _json_encoder.encode(obj)
|
||||
|
||||
# Create a copy of the kwargs to avoid modifying the original
|
||||
encoder_kwargs = kwargs.copy()
|
||||
|
||||
# If cls is specified, use it; otherwise use our custom encoder
|
||||
if 'cls' not in encoder_kwargs:
|
||||
encoder_kwargs['cls'] = OpenHandsJSONEncoder
|
||||
|
||||
return json.dumps(obj, **encoder_kwargs)
|
||||
|
||||
|
||||
def loads(json_str: str, **kwargs) -> Any:
|
||||
"""Create a JSON object from str"""
|
||||
try:
|
||||
return json.loads(json_str, **kwargs)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
depth = 0
|
||||
start = -1
|
||||
for i, char in enumerate(json_str):
|
||||
if char == '{':
|
||||
if depth == 0:
|
||||
start = i
|
||||
depth += 1
|
||||
elif char == '}':
|
||||
depth -= 1
|
||||
if depth == 0 and start != -1:
|
||||
response = json_str[start : i + 1]
|
||||
try:
|
||||
json_str = repair_json(response)
|
||||
return json.loads(json_str, **kwargs)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
||||
raise LLMResponseError(
|
||||
'Invalid JSON in response. Please make sure the response is a valid JSON object.'
|
||||
) from e
|
||||
raise LLMResponseError('No valid JSON object found in response.')
|
||||
@@ -7,6 +7,7 @@
|
||||
# Tag: Legacy-V0
|
||||
# V1 replacement for this module lives in the Software Agent SDK.
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
@@ -245,8 +246,6 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.io import json
|
||||
|
||||
messages_kwarg: (
|
||||
dict[str, Any] | Message | list[dict[str, Any]] | list[Message]
|
||||
) = []
|
||||
@@ -505,7 +504,6 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# noinspection PyBroadException
|
||||
except Exception:
|
||||
pass
|
||||
from openhands.io import json
|
||||
|
||||
logger.debug(
|
||||
f'Model info: {json.dumps({"model": self.config.model, "base_url": self.config.base_url}, indent=2)}'
|
||||
|
||||
@@ -2,7 +2,6 @@ from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.mcp.tool import MCPClientTool
|
||||
from openhands.mcp.utils import (
|
||||
add_mcp_tools_to_agent,
|
||||
call_tool_mcp,
|
||||
convert_mcp_clients_to_tools,
|
||||
create_mcp_clients,
|
||||
@@ -16,6 +15,5 @@ __all__ = [
|
||||
'MCPClientTool',
|
||||
'fetch_mcp_tools_from_config',
|
||||
'call_tool_mcp',
|
||||
'add_mcp_tools_to_agent',
|
||||
'mcp_error_collector',
|
||||
]
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.memory.memory import Memory
|
||||
|
||||
|
||||
from mcp import McpError
|
||||
|
||||
@@ -21,7 +15,6 @@ from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.mcp.client import MCPClient
|
||||
from openhands.mcp.error_collector import mcp_error_collector
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils._redact_compat import (
|
||||
redact_text_secrets,
|
||||
redact_url_params,
|
||||
@@ -269,49 +262,3 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
|
||||
name=action.name,
|
||||
arguments=action.arguments,
|
||||
)
|
||||
|
||||
|
||||
async def add_mcp_tools_to_agent(
|
||||
agent: 'Agent', runtime: Runtime, memory: 'Memory'
|
||||
) -> MCPConfig:
|
||||
"""Add MCP tools to an agent."""
|
||||
import sys
|
||||
|
||||
# Skip MCP tools on Windows
|
||||
if sys.platform == 'win32':
|
||||
logger.info('MCP functionality is disabled on Windows, skipping MCP tools')
|
||||
agent.set_mcp_tools([])
|
||||
return
|
||||
|
||||
assert runtime.runtime_initialized, (
|
||||
'Runtime must be initialized before adding MCP tools'
|
||||
)
|
||||
|
||||
extra_stdio_servers: dict[str, StdioMCPServer] = {}
|
||||
|
||||
# Add microagent MCP tools if available
|
||||
microagent_mcp_configs = memory.get_microagent_mcp_tools()
|
||||
for mcp_cfg in microagent_mcp_configs:
|
||||
for name, server in mcp_cfg.mcpServers.items():
|
||||
if isinstance(server, StdioMCPServer):
|
||||
if name not in extra_stdio_servers:
|
||||
extra_stdio_servers[name] = server
|
||||
logger.warning(f'Added microagent stdio server: {name}')
|
||||
else:
|
||||
logger.warning(
|
||||
f'Microagent MCP config contains non-stdio server {name}, not yet supported.'
|
||||
)
|
||||
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = runtime.get_mcp_config(extra_stdio_servers or None)
|
||||
|
||||
# Fetch the MCP tools
|
||||
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
|
||||
|
||||
tool_names = [tool['function']['name'] for tool in mcp_tools]
|
||||
logger.info(f'Loaded {len(mcp_tools)} MCP tools: {tool_names}')
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
agent.set_mcp_tools(mcp_tools)
|
||||
|
||||
return updated_mcp_config
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Memory Component
|
||||
|
||||
- Short Term History
|
||||
- Memory Condenser
|
||||
|
||||
## Short Term History
|
||||
- Short term history filters the event stream and computes the messages that are injected into the context
|
||||
- It filters out certain events of no interest for the Agent, such as AgentChangeStateObservation or NullAction/NullObservation
|
||||
- When the context window or the token limit set by the user is exceeded, history starts condensing: chunks of messages into summaries.
|
||||
- Each summary is then injected into the context, in the place of the respective chunk it summarizes
|
||||
|
||||
## Memory Condenser
|
||||
- Memory condenser is responsible for summarizing the chunks of events
|
||||
- It summarizes the earlier events first
|
||||
- It starts with the earliest agent actions and observations between two user messages
|
||||
- Then it does the same for later chunks of events between user messages
|
||||
- If there are no more agent events, it summarizes the user messages, this time one by one, if they're large enough and not immediately after an AgentFinishAction event (we assume those are tasks, potentially important)
|
||||
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
|
||||
@@ -1,15 +0,0 @@
|
||||
import openhands.memory.condenser.impl # noqa F401 (we import this to get the condensers registered)
|
||||
from openhands.memory.condenser.condenser import (
|
||||
Condenser,
|
||||
get_condensation_metadata,
|
||||
View,
|
||||
Condensation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Condenser',
|
||||
'get_condensation_metadata',
|
||||
'CONDENSER_REGISTRY',
|
||||
'View',
|
||||
'Condensation',
|
||||
]
|
||||
@@ -1,193 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.condenser_config import CondenserConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.view import View
|
||||
|
||||
CONDENSER_METADATA_KEY = 'condenser_meta'
|
||||
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
|
||||
|
||||
|
||||
def get_condensation_metadata(state: State) -> list[dict[str, Any]]:
|
||||
"""Utility function to retrieve a list of metadata batches from a `State`.
|
||||
|
||||
Args:
|
||||
state: The state to retrieve metadata from.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of metadata batches, each representing a condensation.
|
||||
"""
|
||||
if CONDENSER_METADATA_KEY in state.extra_data:
|
||||
return state.extra_data[CONDENSER_METADATA_KEY]
|
||||
return []
|
||||
|
||||
|
||||
CONDENSER_REGISTRY: dict[type[CondenserConfig], type[Condenser]] = {}
|
||||
"""Registry of condenser configurations to their corresponding condenser classes."""
|
||||
|
||||
|
||||
class Condensation(BaseModel):
|
||||
"""Produced by a condenser to indicate the history has been condensed."""
|
||||
|
||||
action: CondensationAction
|
||||
|
||||
|
||||
class Condenser(ABC):
|
||||
"""Abstract condenser interface.
|
||||
|
||||
Condensers take a list of `Event` objects and reduce them into a potentially smaller list.
|
||||
|
||||
Agents can use condensers to reduce the amount of events they need to consider when deciding which action to take. To use a condenser, agents can call the `condensed_history` method on the current `State` being considered and use the results instead of the full history.
|
||||
|
||||
If the condenser returns a `Condensation` instead of a `View`, the agent should return `Condensation.action` instead of producing its own action. On the next agent step the condenser will use that condensation event to produce a new `View`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._metadata_batch: dict[str, Any] = {}
|
||||
self._llm_metadata: dict[str, Any] = {}
|
||||
|
||||
def add_metadata(self, key: str, value: Any) -> None:
|
||||
"""Add information to the current metadata batch.
|
||||
|
||||
Any key/value pairs added to the metadata batch will be recorded in the `State` at the end of the current condensation.
|
||||
|
||||
Args:
|
||||
key: The key to store the metadata under.
|
||||
|
||||
value: The metadata to store.
|
||||
"""
|
||||
self._metadata_batch[key] = value
|
||||
|
||||
def write_metadata(self, state: State) -> None:
|
||||
"""Write the current batch of metadata to the `State`.
|
||||
|
||||
Resets the current metadata batch: any metadata added after this call will be stored in a new batch and written to the `State` at the end of the next condensation.
|
||||
"""
|
||||
if CONDENSER_METADATA_KEY not in state.extra_data:
|
||||
state.extra_data[CONDENSER_METADATA_KEY] = []
|
||||
if self._metadata_batch:
|
||||
state.extra_data[CONDENSER_METADATA_KEY].append(self._metadata_batch)
|
||||
|
||||
# Since the batch has been written, clear it for the next condensation
|
||||
self._metadata_batch = {}
|
||||
|
||||
@contextmanager
|
||||
def metadata_batch(self, state: State):
|
||||
"""Context manager to ensure batched metadata is always written to the `State`."""
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.write_metadata(state)
|
||||
|
||||
@abstractmethod
|
||||
def condense(self, View) -> View | Condensation:
|
||||
"""Condense a sequence of events into a potentially smaller list.
|
||||
|
||||
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
|
||||
|
||||
Args:
|
||||
View: A view of the history containing all events that should be condensed.
|
||||
|
||||
Returns:
|
||||
View | Condensation: A condensed view of the events or an event indicating the history has been condensed.
|
||||
"""
|
||||
|
||||
def condensed_history(self, state: State) -> View | Condensation:
|
||||
"""Condense the state's history."""
|
||||
if hasattr(self, 'llm'):
|
||||
model_name = self.llm.config.model
|
||||
else:
|
||||
model_name = 'unknown'
|
||||
|
||||
self._llm_metadata = state.to_llm_metadata(
|
||||
model_name=model_name, agent_name='condenser'
|
||||
)
|
||||
with self.metadata_batch(state):
|
||||
return self.condense(state.view)
|
||||
|
||||
@property
|
||||
def llm_metadata(self) -> dict[str, Any]:
|
||||
"""Metadata to be passed to the LLM when using this condenser.
|
||||
|
||||
This metadata is used to provide context about the condensation process and can be used by the LLM to understand how the history was condensed.
|
||||
"""
|
||||
if not self._llm_metadata:
|
||||
logger.warning(
|
||||
'LLM metadata is empty. Ensure to set it in the condenser implementation.'
|
||||
)
|
||||
return self._llm_metadata
|
||||
|
||||
@classmethod
|
||||
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
|
||||
"""Register a new condenser configuration type.
|
||||
|
||||
Instances of registered configuration types can be passed to `from_config` to create instances of the corresponding condenser.
|
||||
|
||||
Args:
|
||||
configuration_type: The type of configuration used to create instances of the condenser.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration type is already registered.
|
||||
"""
|
||||
if configuration_type in CONDENSER_REGISTRY:
|
||||
raise ValueError(
|
||||
f'Condenser configuration {configuration_type} is already registered'
|
||||
)
|
||||
CONDENSER_REGISTRY[configuration_type] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: CondenserConfig, llm_registry: LLMRegistry
|
||||
) -> Condenser:
|
||||
"""Create a condenser from a configuration object.
|
||||
|
||||
Args:
|
||||
config: Configuration for the condenser.
|
||||
|
||||
Returns:
|
||||
Condenser: A condenser instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condenser type is not recognized.
|
||||
"""
|
||||
try:
|
||||
condenser_class = CONDENSER_REGISTRY[type(config)]
|
||||
return condenser_class.from_config(config, llm_registry)
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown condenser config: {config}')
|
||||
|
||||
|
||||
class RollingCondenser(Condenser, ABC):
|
||||
"""Base class for a specialized condenser strategy that applies condensation to a rolling history.
|
||||
|
||||
The rolling history is generated by `View.from_events`, which analyzes all events in the history and produces a `View` object representing what will be sent to the LLM.
|
||||
|
||||
If `should_condense` says so, the condenser is then responsible for generating a `Condensation` object from the `View` object. This will be added to the event history which should -- when given to `get_view` -- produce the condensed `View` to be passed to the LLM.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def should_condense(self, view: View) -> bool:
|
||||
"""Determine if a view should be condensed."""
|
||||
|
||||
@abstractmethod
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
"""Get the condensation from a view."""
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
# If we trigger the condenser-specific condensation threshold, compute and return
|
||||
# the condensation.
|
||||
if self.should_condense(view):
|
||||
return self.get_condensation(view)
|
||||
|
||||
# Otherwise we're safe to just return the view.
|
||||
else:
|
||||
return view
|
||||
@@ -1,41 +0,0 @@
|
||||
from openhands.memory.condenser.impl.amortized_forgetting_condenser import (
|
||||
AmortizedForgettingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.browser_output_condenser import (
|
||||
BrowserOutputCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.conversation_window_condenser import (
|
||||
ConversationWindowCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.llm_attention_condenser import (
|
||||
ImportantEventSelection,
|
||||
LLMAttentionCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.llm_summarizing_condenser import (
|
||||
LLMSummarizingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.no_op_condenser import NoOpCondenser
|
||||
from openhands.memory.condenser.impl.observation_masking_condenser import (
|
||||
ObservationMaskingCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
|
||||
from openhands.memory.condenser.impl.recent_events_condenser import (
|
||||
RecentEventsCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.structured_summary_condenser import (
|
||||
StructuredSummaryCondenser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'AmortizedForgettingCondenser',
|
||||
'LLMAttentionCondenser',
|
||||
'ImportantEventSelection',
|
||||
'LLMSummarizingCondenser',
|
||||
'NoOpCondenser',
|
||||
'ObservationMaskingCondenser',
|
||||
'BrowserOutputCondenser',
|
||||
'RecentEventsCondenser',
|
||||
'StructuredSummaryCondenser',
|
||||
'CondenserPipeline',
|
||||
'ConversationWindowCondenser',
|
||||
]
|
||||
@@ -1,69 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import (
|
||||
Condensation,
|
||||
RollingCondenser,
|
||||
View,
|
||||
)
|
||||
|
||||
|
||||
class AmortizedForgettingCondenser(RollingCondenser):
|
||||
"""A condenser that maintains a condensed history and forgets old events when it grows too large."""
|
||||
|
||||
def __init__(self, max_size: int = 100, keep_first: int = 0):
|
||||
"""Initialize the condenser.
|
||||
|
||||
Args:
|
||||
max_size: Maximum size of history before forgetting.
|
||||
keep_first: Number of initial events to always keep.
|
||||
|
||||
Raises:
|
||||
ValueError: If keep_first is greater than max_size, keep_first is negative, or max_size is non-positive.
|
||||
"""
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
|
||||
super().__init__()
|
||||
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
target_size = self.max_size // 2
|
||||
head = view[: self.keep_first]
|
||||
|
||||
events_from_tail = target_size - len(head)
|
||||
tail = view[-events_from_tail:]
|
||||
|
||||
event_ids_to_keep = {event.id for event in head + tail}
|
||||
event_ids_to_forget = {event.id for event in view} - event_ids_to_keep
|
||||
|
||||
event = CondensationAction(
|
||||
forgotten_events_start_id=min(event_ids_to_forget),
|
||||
forgotten_events_end_id=max(event_ids_to_forget),
|
||||
)
|
||||
|
||||
return Condensation(action=event)
|
||||
|
||||
def should_condense(self, view: View) -> bool:
|
||||
return len(view) > self.max_size or view.unhandled_condensation_request
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: AmortizedForgettingCondenserConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> AmortizedForgettingCondenser:
|
||||
return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'}))
|
||||
|
||||
|
||||
AmortizedForgettingCondenser.register_config(AmortizedForgettingCondenserConfig)
|
||||
@@ -1,49 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
class BrowserOutputCondenser(Condenser):
|
||||
"""A condenser that masks the observations from browser outputs outside of a recent attention window.
|
||||
|
||||
The intent here is to mask just the browser outputs and leave everything else untouched. This is important because currently we provide screenshots and accessibility trees as input to the model for browser observations. These are really large and consume a lot of tokens without any benefits in performance. So we want to mask all such observations from all previous timesteps, and leave only the most recent one in context.
|
||||
"""
|
||||
|
||||
def __init__(self, attention_window: int = 1):
|
||||
self.attention_window = attention_window
|
||||
super().__init__()
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Replace the content of browser observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
cnt: int = 0
|
||||
for event in reversed(view):
|
||||
if (
|
||||
isinstance(event, BrowserOutputObservation)
|
||||
and cnt >= self.attention_window
|
||||
):
|
||||
results.append(
|
||||
AgentCondensationObservation(
|
||||
f'Visited URL {event.url}\nContent omitted'
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(event)
|
||||
if isinstance(event, BrowserOutputObservation):
|
||||
cnt += 1
|
||||
|
||||
return View(events=list(reversed(results)))
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> BrowserOutputCondenser:
|
||||
return BrowserOutputCondenser(**config.model_dump(exclude={'type'}))
|
||||
|
||||
|
||||
BrowserOutputCondenser.register_config(BrowserOutputCondenserConfig)
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import ConversationWindowCondenserConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import (
|
||||
CondensationAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
|
||||
|
||||
|
||||
class ConversationWindowCondenser(RollingCondenser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
"""Apply conversation window truncation similar to _apply_conversation_window.
|
||||
|
||||
This method:
|
||||
1. Identifies essential initial events (System Message, First User Message, Recall Observation)
|
||||
2. Keeps roughly half of the history
|
||||
3. Ensures action-observation pairs are preserved
|
||||
4. Returns a CondensationAction specifying which events to forget
|
||||
"""
|
||||
events = view.events
|
||||
|
||||
# Handle empty history
|
||||
if not events:
|
||||
# No events to condense
|
||||
action = CondensationAction(forgotten_event_ids=[])
|
||||
return Condensation(action=action)
|
||||
|
||||
# 1. Identify essential initial events
|
||||
system_message: SystemMessageAction | None = None
|
||||
first_user_msg: MessageAction | None = None
|
||||
recall_action: RecallAction | None = None
|
||||
recall_observation: Observation | None = None
|
||||
|
||||
# Find System Message (should be the first event, if it exists)
|
||||
system_message = next(
|
||||
(e for e in events if isinstance(e, SystemMessageAction)), None
|
||||
)
|
||||
|
||||
# Find First User Message
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if first_user_msg is None:
|
||||
logger.warning(
|
||||
'No first user message found in history during condensation.'
|
||||
)
|
||||
# Return empty condensation if no user message
|
||||
action = CondensationAction(forgotten_event_ids=[])
|
||||
return Condensation(action=action)
|
||||
|
||||
# Find the first user message index
|
||||
first_user_msg_index = -1
|
||||
for i, event in enumerate(events):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
first_user_msg_index = i
|
||||
break
|
||||
|
||||
# Find Recall Action and Observation related to the First User Message
|
||||
for i in range(first_user_msg_index + 1, len(events)):
|
||||
event = events[i]
|
||||
if (
|
||||
isinstance(event, RecallAction)
|
||||
and event.query == first_user_msg.content
|
||||
):
|
||||
recall_action = event
|
||||
# Look for its observation
|
||||
for j in range(i + 1, len(events)):
|
||||
obs_event = events[j]
|
||||
if (
|
||||
isinstance(obs_event, Observation)
|
||||
and obs_event.cause == recall_action.id
|
||||
):
|
||||
recall_observation = obs_event
|
||||
break
|
||||
break
|
||||
|
||||
# Collect essential events
|
||||
essential_events: list[int] = [] # Store event IDs
|
||||
if system_message:
|
||||
essential_events.append(system_message.id)
|
||||
essential_events.append(first_user_msg.id)
|
||||
if recall_action:
|
||||
essential_events.append(recall_action.id)
|
||||
if recall_observation:
|
||||
essential_events.append(recall_observation.id)
|
||||
|
||||
# 2. Determine which events to keep
|
||||
num_essential_events = len(essential_events)
|
||||
total_events = len(events)
|
||||
num_non_essential_events = total_events - num_essential_events
|
||||
|
||||
# Keep roughly half of the non-essential events
|
||||
num_recent_to_keep = max(1, num_non_essential_events // 2)
|
||||
|
||||
# Calculate the starting index for recent events to keep
|
||||
slice_start_index = total_events - num_recent_to_keep
|
||||
slice_start_index = max(0, slice_start_index)
|
||||
|
||||
# 3. Handle dangling observations at the start of the slice
|
||||
# Find the first non-observation event in the slice
|
||||
recent_events_slice = events[slice_start_index:]
|
||||
first_valid_event_index_in_slice = 0
|
||||
for i, event in enumerate(recent_events_slice):
|
||||
if not isinstance(event, Observation):
|
||||
first_valid_event_index_in_slice = i
|
||||
break
|
||||
else:
|
||||
# All events in the slice are observations
|
||||
first_valid_event_index_in_slice = len(recent_events_slice)
|
||||
|
||||
# Check if all events in the recent slice are dangling observations
|
||||
if first_valid_event_index_in_slice == len(recent_events_slice):
|
||||
logger.warning(
|
||||
'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
|
||||
)
|
||||
|
||||
# Calculate the actual index in the full events list
|
||||
first_valid_event_index = slice_start_index + first_valid_event_index_in_slice
|
||||
|
||||
if first_valid_event_index_in_slice > 0:
|
||||
logger.debug(
|
||||
f'Removed {first_valid_event_index_in_slice} dangling observation(s) '
|
||||
f'from the start of recent event slice.'
|
||||
)
|
||||
|
||||
# 4. Determine which events to keep and which to forget
|
||||
events_to_keep: set[int] = set(essential_events)
|
||||
|
||||
# Add recent events starting from first_valid_event_index
|
||||
for i in range(first_valid_event_index, total_events):
|
||||
events_to_keep.add(events[i].id)
|
||||
|
||||
# Calculate which events to forget
|
||||
all_event_ids = {e.id for e in events}
|
||||
forgotten_event_ids = sorted(all_event_ids - events_to_keep)
|
||||
|
||||
logger.info(
|
||||
f'ConversationWindowCondenser: Keeping {len(events_to_keep)} events, '
|
||||
f'forgetting {len(forgotten_event_ids)} events.'
|
||||
)
|
||||
|
||||
# Create the condensation action
|
||||
if forgotten_event_ids:
|
||||
# Use range if the forgotten events are contiguous
|
||||
if (
|
||||
len(forgotten_event_ids) > 1
|
||||
and forgotten_event_ids[-1] - forgotten_event_ids[0]
|
||||
== len(forgotten_event_ids) - 1
|
||||
):
|
||||
action = CondensationAction(
|
||||
forgotten_events_start_id=forgotten_event_ids[0],
|
||||
forgotten_events_end_id=forgotten_event_ids[-1],
|
||||
)
|
||||
else:
|
||||
action = CondensationAction(forgotten_event_ids=forgotten_event_ids)
|
||||
else:
|
||||
action = CondensationAction(forgotten_event_ids=[])
|
||||
|
||||
return Condensation(action=action)
|
||||
|
||||
def should_condense(self, view: View) -> bool:
|
||||
return view.unhandled_condensation_request
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
_config: ConversationWindowCondenserConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> ConversationWindowCondenser:
|
||||
return ConversationWindowCondenser()
|
||||
|
||||
|
||||
ConversationWindowCondenser.register_config(ConversationWindowCondenserConfig)
|
||||
@@ -1,140 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import (
|
||||
Condensation,
|
||||
RollingCondenser,
|
||||
View,
|
||||
)
|
||||
|
||||
|
||||
class ImportantEventSelection(BaseModel):
|
||||
"""Utility class for the `LLMAttentionCondenser` that forces the LLM to return a list of integers."""
|
||||
|
||||
ids: list[int]
|
||||
|
||||
|
||||
class LLMAttentionCondenser(RollingCondenser):
|
||||
"""Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
max_size: int = 100,
|
||||
keep_first: int = 1,
|
||||
):
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
self.llm = llm
|
||||
|
||||
# This condenser relies on the `response_schema` feature, which is not supported by all LLMs
|
||||
if not supports_response_schema(
|
||||
model=self.llm.config.model,
|
||||
custom_llm_provider=self.llm.config.custom_llm_provider,
|
||||
):
|
||||
raise ValueError(
|
||||
"The LLM model must support the 'response_schema' parameter to use the LLMAttentionCondenser."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
target_size = self.max_size // 2
|
||||
head_event_ids = [event.id for event in view.events[: self.keep_first]]
|
||||
|
||||
events_from_tail = target_size - len(head_event_ids)
|
||||
|
||||
message: str = """You will be given a list of actions, observations, and thoughts from a coding agent.
|
||||
Each item in the list has an identifier. Please sort the identifiers in order of how important the
|
||||
contents of the item are for the next step of the coding agent's task, from most important to least
|
||||
important."""
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{'content': message, 'role': 'user'},
|
||||
*[
|
||||
{
|
||||
'content': f'<ID>{e.id}</ID>\n<CONTENT>{e.message}</CONTENT>',
|
||||
'role': 'user',
|
||||
}
|
||||
for e in view
|
||||
],
|
||||
],
|
||||
response_format={
|
||||
'type': 'json_schema',
|
||||
'json_schema': {
|
||||
'name': 'ImportantEventSelection',
|
||||
'schema': ImportantEventSelection.model_json_schema(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
response_ids = ImportantEventSelection.model_validate_json(
|
||||
response.choices[0].message.content
|
||||
).ids
|
||||
|
||||
self.add_metadata('metrics', self.llm.metrics.get())
|
||||
|
||||
# Filter out any IDs from the head and trim the results down
|
||||
response_ids = [
|
||||
response_id
|
||||
for response_id in response_ids
|
||||
if response_id not in head_event_ids
|
||||
][:events_from_tail]
|
||||
|
||||
# If the response IDs aren't _long_ enough, iterate backwards through the events and add any unfound IDs to the list.
|
||||
for event in reversed(view):
|
||||
if len(response_ids) >= events_from_tail:
|
||||
break
|
||||
if event.id not in response_ids:
|
||||
response_ids.append(event.id)
|
||||
|
||||
# Now that we've found the right number of events to keep, convert this into a list of events to forget.
|
||||
event = CondensationAction(
|
||||
forgotten_event_ids=[
|
||||
event.id
|
||||
for event in view
|
||||
if event.id not in response_ids and event.id not in head_event_ids
|
||||
],
|
||||
)
|
||||
|
||||
return Condensation(action=event)
|
||||
|
||||
def should_condense(self, view: View) -> bool:
|
||||
return len(view) > self.max_size or view.unhandled_condensation_request
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> LLMAttentionCondenser:
|
||||
# This condenser cannot take advantage of prompt caching. If it happens
|
||||
# to be set, we'll pay for the cache writes but never get a chance to
|
||||
# save on a read.
|
||||
llm_config = config.llm_config.model_copy()
|
||||
llm_config.caching_prompt = False
|
||||
|
||||
llm = llm_registry.get_llm('condenser', llm_config)
|
||||
|
||||
return LLMAttentionCondenser(
|
||||
llm=llm,
|
||||
max_size=config.max_size,
|
||||
keep_first=config.keep_first,
|
||||
)
|
||||
|
||||
|
||||
LLMAttentionCondenser.register_config(LLMAttentionCondenserConfig)
|
||||
@@ -1,182 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import (
|
||||
Condensation,
|
||||
RollingCondenser,
|
||||
View,
|
||||
)
|
||||
|
||||
|
||||
class LLMSummarizingCondenser(RollingCondenser):
|
||||
"""A condenser that summarizes forgotten events.
|
||||
|
||||
Maintains a condensed history and forgets old events when it grows too large,
|
||||
keeping a special summarization event after the prefix that summarizes all previous summarizations
|
||||
and newly forgotten events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
max_size: int = 100,
|
||||
keep_first: int = 1,
|
||||
max_event_length: int = 10_000,
|
||||
):
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
self.max_event_length = max_event_length
|
||||
self.llm = llm
|
||||
|
||||
super().__init__()
|
||||
|
||||
def _truncate(self, content: str) -> str:
|
||||
"""Truncate the content to fit within the specified maximum event length."""
|
||||
return truncate_content(content, max_chars=self.max_event_length)
|
||||
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
head = view[: self.keep_first]
|
||||
target_size = self.max_size // 2
|
||||
# Number of events to keep from the tail -- target size, minus however many
|
||||
# prefix events from the head, minus one for the summarization event
|
||||
events_from_tail = target_size - len(head) - 1
|
||||
|
||||
summary_event = (
|
||||
view[self.keep_first]
|
||||
if isinstance(view[self.keep_first], AgentCondensationObservation)
|
||||
else AgentCondensationObservation('No events summarized')
|
||||
)
|
||||
|
||||
# Identify events to be forgotten (those not in head or tail)
|
||||
forgotten_events = []
|
||||
for event in view[self.keep_first : -events_from_tail]:
|
||||
if not isinstance(event, AgentCondensationObservation):
|
||||
forgotten_events.append(event)
|
||||
|
||||
# Construct prompt for summarization
|
||||
prompt = """You are maintaining a context-aware state summary for an interactive agent.
|
||||
You will be given a list of events corresponding to actions taken by the agent, and the most recent previous summary if one exists.
|
||||
If the events being summarized contain ANY task-tracking, you MUST include a TASK_TRACKING section to maintain continuity.
|
||||
When referencing tasks make sure to preserve exact task IDs and statuses.
|
||||
|
||||
Track:
|
||||
|
||||
USER_CONTEXT: (Preserve essential user requirements, goals, and clarifications in concise form)
|
||||
|
||||
TASK_TRACKING: {Active tasks, their IDs and statuses - PRESERVE TASK IDs}
|
||||
|
||||
COMPLETED: (Tasks completed so far, with brief results)
|
||||
PENDING: (Tasks that still need to be done)
|
||||
CURRENT_STATE: (Current variables, data structures, or relevant state)
|
||||
|
||||
For code-specific tasks, also include:
|
||||
CODE_STATE: {File paths, function signatures, data structures}
|
||||
TESTS: {Failing cases, error messages, outputs}
|
||||
CHANGES: {Code edits, variable updates}
|
||||
DEPS: {Dependencies, imports, external calls}
|
||||
VERSION_CONTROL_STATUS: {Repository state, current branch, PR status, commit history}
|
||||
|
||||
PRIORITIZE:
|
||||
1. Adapt tracking format to match the actual task type
|
||||
2. Capture key user requirements and goals
|
||||
3. Distinguish between completed and pending tasks
|
||||
4. Keep all sections concise and relevant
|
||||
|
||||
SKIP: Tracking irrelevant details for the current task type
|
||||
|
||||
Example formats:
|
||||
|
||||
For code tasks:
|
||||
USER_CONTEXT: Fix FITS card float representation issue
|
||||
COMPLETED: Modified mod_float() in card.py, all tests passing
|
||||
PENDING: Create PR, update documentation
|
||||
CODE_STATE: mod_float() in card.py updated
|
||||
TESTS: test_format() passed
|
||||
CHANGES: str(val) replaces f"{val:.16G}"
|
||||
DEPS: None modified
|
||||
VERSION_CONTROL_STATUS: Branch: fix-float-precision, Latest commit: a1b2c3d
|
||||
|
||||
For other tasks:
|
||||
USER_CONTEXT: Write 20 haikus based on coin flip results
|
||||
COMPLETED: 15 haikus written for results [T,H,T,H,T,H,T,T,H,T,H,T,H,T,H]
|
||||
PENDING: 5 more haikus needed
|
||||
CURRENT_STATE: Last flip: Heads, Haiku count: 15/20"""
|
||||
|
||||
prompt += '\n\n'
|
||||
|
||||
# Add the previous summary if it exists. We'll always have a summary
|
||||
# event, but the types aren't precise enought to guarantee that it has a
|
||||
# message attribute.
|
||||
summary_event_content = self._truncate(
|
||||
summary_event.message if summary_event.message else ''
|
||||
)
|
||||
prompt += f'<PREVIOUS SUMMARY>\n{summary_event_content}\n</PREVIOUS SUMMARY>\n'
|
||||
|
||||
prompt += '\n\n'
|
||||
|
||||
# Add all events that are being forgotten. We use the string
|
||||
# representation defined by the event, and truncate it if necessary.
|
||||
for forgotten_event in forgotten_events:
|
||||
event_content = self._truncate(str(forgotten_event))
|
||||
prompt += f'<EVENT id={forgotten_event.id}>\n{event_content}\n</EVENT>\n'
|
||||
|
||||
prompt += 'Now summarize the events using the rules above.'
|
||||
|
||||
messages = [Message(role='user', content=[TextContent(text=prompt)])]
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=self.llm.format_messages_for_llm(messages),
|
||||
extra_body={'metadata': self.llm_metadata},
|
||||
)
|
||||
summary = response.choices[0].message.content
|
||||
|
||||
self.add_metadata('response', response.model_dump())
|
||||
self.add_metadata('metrics', self.llm.metrics.get())
|
||||
|
||||
return Condensation(
|
||||
action=CondensationAction(
|
||||
forgotten_events_start_id=min(event.id for event in forgotten_events),
|
||||
forgotten_events_end_id=max(event.id for event in forgotten_events),
|
||||
summary=summary,
|
||||
summary_offset=self.keep_first,
|
||||
)
|
||||
)
|
||||
|
||||
def should_condense(self, view: View) -> bool:
|
||||
return len(view) > self.max_size or view.unhandled_condensation_request
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> LLMSummarizingCondenser:
|
||||
# This condenser cannot take advantage of prompt caching. If it happens
|
||||
# to be set, we'll pay for the cache writes but never get a chance to
|
||||
# save on a read.
|
||||
llm_config = config.llm_config.model_copy()
|
||||
llm_config.caching_prompt = False
|
||||
llm = llm_registry.get_llm('condenser', llm_config)
|
||||
|
||||
return LLMSummarizingCondenser(
|
||||
llm=llm,
|
||||
max_size=config.max_size,
|
||||
keep_first=config.keep_first,
|
||||
max_event_length=config.max_event_length,
|
||||
)
|
||||
|
||||
|
||||
LLMSummarizingCondenser.register_config(LLMSummarizingCondenserConfig)
|
||||
@@ -1,22 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
class NoOpCondenser(Condenser):
|
||||
"""A condenser that does nothing to the event sequence."""
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Returns the list of events unchanged."""
|
||||
return view
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> NoOpCondenser:
|
||||
return NoOpCondenser()
|
||||
|
||||
|
||||
NoOpCondenser.register_config(NoOpCondenserConfig)
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import ObservationMaskingCondenserConfig
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
class ObservationMaskingCondenser(Condenser):
|
||||
"""A condenser that masks the values of observations outside of a recent attention window."""
|
||||
|
||||
def __init__(self, attention_window: int = 5):
|
||||
self.attention_window = attention_window
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Replace the content of observations outside of the attention window with a placeholder."""
|
||||
results: list[Event] = []
|
||||
for i, event in enumerate(view):
|
||||
if isinstance(event, Observation) and i < len(view) - self.attention_window:
|
||||
results.append(AgentCondensationObservation('<MASKED>'))
|
||||
else:
|
||||
results.append(event)
|
||||
|
||||
return View(events=results)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ObservationMaskingCondenserConfig,
|
||||
llm_registry: LLMRegistry,
|
||||
) -> ObservationMaskingCondenser:
|
||||
return ObservationMaskingCondenser(**config.model_dump(exclude={'type'}))
|
||||
|
||||
|
||||
ObservationMaskingCondenser.register_config(ObservationMaskingCondenserConfig)
|
||||
@@ -1,50 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.condenser_config import CondenserPipelineConfig
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser
|
||||
from openhands.memory.view import View
|
||||
|
||||
|
||||
class CondenserPipeline(Condenser):
|
||||
"""Combines multiple condensers into a single condenser.
|
||||
|
||||
This is useful for creating a pipeline of condensers that can be chained together to achieve very specific condensation aims. Each condenser is run in sequence, passing the output view of one to the next, until we reach the end or a `CondensationAction` is returned instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *condenser: Condenser) -> None:
|
||||
self.condensers = list(condenser)
|
||||
super().__init__()
|
||||
|
||||
@contextmanager
|
||||
def metadata_batch(self, state: State):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# The parent class assumes the metadata is stored in the "calling
|
||||
# condenser" -- since we're not threading a State through to each
|
||||
# step in the pipeline, we need to walk back through the pipeline
|
||||
# and manually collect the relevant metadata.
|
||||
for condenser in self.condensers:
|
||||
condenser.write_metadata(state)
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
result: View | Condensation = view
|
||||
for condenser in self.condensers:
|
||||
result = condenser.condense(result)
|
||||
if isinstance(result, Condensation):
|
||||
break
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry
|
||||
) -> CondenserPipeline:
|
||||
condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers]
|
||||
return CondenserPipeline(*condensers)
|
||||
|
||||
|
||||
CondenserPipeline.register_config(CondenserPipelineConfig)
|
||||
@@ -1,31 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import RecentEventsCondenserConfig
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import Condensation, Condenser, View
|
||||
|
||||
|
||||
class RecentEventsCondenser(Condenser):
|
||||
"""A condenser that only keeps a certain number of the most recent events."""
|
||||
|
||||
def __init__(self, keep_first: int = 1, max_events: int = 10):
|
||||
self.keep_first = keep_first
|
||||
self.max_events = max_events
|
||||
|
||||
super().__init__()
|
||||
|
||||
def condense(self, view: View) -> View | Condensation:
|
||||
"""Keep only the most recent events (up to `max_events`)."""
|
||||
head = view[: self.keep_first]
|
||||
tail_length = max(0, self.max_events - len(head))
|
||||
tail = view[-tail_length:] if tail_length > 0 else []
|
||||
return View(events=head + tail)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> RecentEventsCondenser:
|
||||
return RecentEventsCondenser(**config.model_dump(exclude={'type'}))
|
||||
|
||||
|
||||
RecentEventsCondenser.register_config(RecentEventsCondenserConfig)
|
||||
@@ -1,329 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.config.condenser_config import (
|
||||
StructuredSummaryCondenserConfig,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser.condenser import (
|
||||
Condensation,
|
||||
RollingCondenser,
|
||||
View,
|
||||
)
|
||||
|
||||
|
||||
class StateSummary(BaseModel):
|
||||
"""A structured representation summarizing the state of the agent and the task."""
|
||||
|
||||
# Required core fields
|
||||
user_context: str = Field(
|
||||
default='',
|
||||
description='Essential user requirements, goals, and clarifications in concise form.',
|
||||
)
|
||||
completed_tasks: str = Field(
|
||||
default='', description='List of tasks completed so far with brief results.'
|
||||
)
|
||||
pending_tasks: str = Field(
|
||||
default='', description='List of tasks that still need to be done.'
|
||||
)
|
||||
current_state: str = Field(
|
||||
default='',
|
||||
description='Current variables, data structures, or other relevant state information.',
|
||||
)
|
||||
|
||||
# Code state fields
|
||||
files_modified: str = Field(
|
||||
default='', description='List of files that have been created or modified.'
|
||||
)
|
||||
function_changes: str = Field(
|
||||
default='', description='List of functions that have been created or modified.'
|
||||
)
|
||||
data_structures: str = Field(
|
||||
default='', description='List of key data structures in use or modified.'
|
||||
)
|
||||
|
||||
# Test status fields
|
||||
tests_written: str = Field(
|
||||
default='',
|
||||
description='Whether tests have been written for the changes. True, false, or unknown.',
|
||||
)
|
||||
tests_passing: str = Field(
|
||||
default='',
|
||||
description='Whether all tests are currently passing. True, false, or unknown.',
|
||||
)
|
||||
failing_tests: str = Field(
|
||||
default='', description='List of names or descriptions of any failing tests.'
|
||||
)
|
||||
error_messages: str = Field(
|
||||
default='', description='List of key error messages encountered.'
|
||||
)
|
||||
|
||||
# Version control fields
|
||||
branch_created: str = Field(
|
||||
default='',
|
||||
description='Whether a branch has been created for this work. True, false, or unknown.',
|
||||
)
|
||||
branch_name: str = Field(
|
||||
default='', description='Name of the current working branch if known.'
|
||||
)
|
||||
commits_made: str = Field(
|
||||
default='',
|
||||
description='Whether any commits have been made. True, false, or unknown.',
|
||||
)
|
||||
pr_created: str = Field(
|
||||
default='',
|
||||
description='Whether a pull request has been created. True, false, or unknown.',
|
||||
)
|
||||
pr_status: str = Field(
|
||||
default='',
|
||||
description="Status of any pull request: 'draft', 'open', 'merged', 'closed', or 'unknown'.",
|
||||
)
|
||||
|
||||
# Other fields
|
||||
dependencies: str = Field(
|
||||
default='',
|
||||
description='List of dependencies or imports that have been added or modified.',
|
||||
)
|
||||
other_relevant_context: str = Field(
|
||||
default='',
|
||||
description="Any other important information that doesn't fit into the categories above.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tool_description(cls) -> dict[str, Any]:
|
||||
"""Description of a tool whose arguments are the fields of this class.
|
||||
|
||||
Can be given to an LLM to force structured generation.
|
||||
"""
|
||||
properties = {}
|
||||
|
||||
# Build properties dictionary from field information
|
||||
for field_name, field in cls.model_fields.items():
|
||||
description = field.description or ''
|
||||
|
||||
properties[field_name] = {'type': 'string', 'description': description}
|
||||
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'create_state_summary',
|
||||
'description': 'Creates a comprehensive summary of the current state of the interaction to preserve context when history grows too large. You must include non-empty values for user_context, completed_tasks, and pending_tasks.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': properties,
|
||||
'required': ['user_context', 'completed_tasks', 'pending_tasks'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Format the state summary in a clear way for Claude 3.7 Sonnet."""
|
||||
sections = [
|
||||
'# State Summary',
|
||||
'## Core Information',
|
||||
f'**User Context**: {self.user_context}',
|
||||
f'**Completed Tasks**: {self.completed_tasks}',
|
||||
f'**Pending Tasks**: {self.pending_tasks}',
|
||||
f'**Current State**: {self.current_state}',
|
||||
'## Code Changes',
|
||||
f'**Files Modified**: {self.files_modified}',
|
||||
f'**Function Changes**: {self.function_changes}',
|
||||
f'**Data Structures**: {self.data_structures}',
|
||||
f'**Dependencies**: {self.dependencies}',
|
||||
'## Testing Status',
|
||||
f'**Tests Written**: {self.tests_written}',
|
||||
f'**Tests Passing**: {self.tests_passing}',
|
||||
f'**Failing Tests**: {self.failing_tests}',
|
||||
f'**Error Messages**: {self.error_messages}',
|
||||
'## Version Control',
|
||||
f'**Branch Created**: {self.branch_created}',
|
||||
f'**Branch Name**: {self.branch_name}',
|
||||
f'**Commits Made**: {self.commits_made}',
|
||||
f'**PR Created**: {self.pr_created}',
|
||||
f'**PR Status**: {self.pr_status}',
|
||||
'## Additional Context',
|
||||
f'**Other Relevant Context**: {self.other_relevant_context}',
|
||||
]
|
||||
|
||||
# Join all sections with double newlines
|
||||
return '\n\n'.join(sections)
|
||||
|
||||
|
||||
class StructuredSummaryCondenser(RollingCondenser):
|
||||
"""A condenser that summarizes forgotten events.
|
||||
|
||||
Maintains a condensed history and forgets old events when it grows too large. Uses structured generation via function-calling to produce summaries that replace forgotten events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
max_size: int = 100,
|
||||
keep_first: int = 1,
|
||||
max_event_length: int = 10_000,
|
||||
):
|
||||
if keep_first >= max_size // 2:
|
||||
raise ValueError(
|
||||
f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
|
||||
)
|
||||
if keep_first < 0:
|
||||
raise ValueError(f'keep_first ({keep_first}) cannot be negative')
|
||||
if max_size < 1:
|
||||
raise ValueError(f'max_size ({max_size}) cannot be non-positive')
|
||||
|
||||
self.max_size = max_size
|
||||
self.keep_first = keep_first
|
||||
self.max_event_length = max_event_length
|
||||
self.llm = llm
|
||||
if not self.llm.is_function_calling_active():
|
||||
raise ValueError(
|
||||
'LLM must support function calling to use StructuredSummaryCondenser'
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def _truncate(self, content: str) -> str:
|
||||
"""Truncate the content to fit within the specified maximum event length."""
|
||||
return truncate_content(content, max_chars=self.max_event_length)
|
||||
|
||||
def get_condensation(self, view: View) -> Condensation:
|
||||
head = view[: self.keep_first]
|
||||
target_size = self.max_size // 2
|
||||
# Number of events to keep from the tail -- target size, minus however many
|
||||
# prefix events from the head, minus one for the summarization event
|
||||
events_from_tail = target_size - len(head) - 1
|
||||
|
||||
summary_event = (
|
||||
view[self.keep_first]
|
||||
if isinstance(view[self.keep_first], AgentCondensationObservation)
|
||||
else AgentCondensationObservation('No events summarized')
|
||||
)
|
||||
|
||||
# Identify events to be forgotten (those not in head or tail)
|
||||
forgotten_events = []
|
||||
for event in view[self.keep_first : -events_from_tail]:
|
||||
if not isinstance(event, AgentCondensationObservation):
|
||||
forgotten_events.append(event)
|
||||
|
||||
# Construct prompt for summarization
|
||||
prompt = """You are maintaining a context-aware state summary for an interactive software agent. This summary is critical because it:
|
||||
1. Preserves essential context when conversation history grows too large
|
||||
2. Prevents lost work when the session length exceeds token limits
|
||||
3. Helps maintain continuity across multiple interactions
|
||||
|
||||
You will be given:
|
||||
- A list of events (actions taken by the agent)
|
||||
- The most recent previous summary (if one exists)
|
||||
|
||||
Capture all relevant information, especially:
|
||||
- User requirements that were explicitly stated
|
||||
- Work that has been completed
|
||||
- Tasks that remain pending
|
||||
- Current state of code, variables, and data structures
|
||||
- The status of any version control operations"""
|
||||
|
||||
prompt += '\n\n'
|
||||
|
||||
# Add the previous summary if it exists. We'll always have a summary
|
||||
# event, but the types aren't precise enought to guarantee that it has a
|
||||
# message attribute.
|
||||
summary_event_content = self._truncate(
|
||||
summary_event.message if summary_event.message else ''
|
||||
)
|
||||
prompt += f'<PREVIOUS SUMMARY>\n{summary_event_content}\n</PREVIOUS SUMMARY>\n'
|
||||
|
||||
prompt += '\n\n'
|
||||
|
||||
# Add all events that are being forgotten. We use the string
|
||||
# representation defined by the event, and truncate it if necessary.
|
||||
for forgotten_event in forgotten_events:
|
||||
event_content = self._truncate(str(forgotten_event))
|
||||
prompt += f'<EVENT id={forgotten_event.id}>\n{event_content}\n</EVENT>\n'
|
||||
|
||||
messages = [Message(role='user', content=[TextContent(text=prompt)])]
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=self.llm.format_messages_for_llm(messages),
|
||||
tools=[StateSummary.tool_description()],
|
||||
tool_choice={
|
||||
'type': 'function',
|
||||
'function': {'name': 'create_state_summary'},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract the message containing tool calls
|
||||
message = response.choices[0].message
|
||||
|
||||
# Check if there are tool calls
|
||||
if not hasattr(message, 'tool_calls') or not message.tool_calls:
|
||||
raise ValueError('No tool calls found in response')
|
||||
|
||||
# Find the create_state_summary tool call
|
||||
summary_tool_call = None
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.function.name == 'create_state_summary':
|
||||
summary_tool_call = tool_call
|
||||
break
|
||||
|
||||
if not summary_tool_call:
|
||||
raise ValueError('create_state_summary tool call not found')
|
||||
|
||||
# Parse the arguments
|
||||
args_json = summary_tool_call.function.arguments
|
||||
args_dict = json.loads(args_json)
|
||||
|
||||
# Create a StateSummary object
|
||||
summary = StateSummary.model_validate(args_dict)
|
||||
|
||||
except (ValueError, AttributeError, KeyError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f'Failed to parse summary tool call: {e}. Using empty summary.'
|
||||
)
|
||||
summary = StateSummary()
|
||||
|
||||
self.add_metadata('response', response.model_dump())
|
||||
self.add_metadata('metrics', self.llm.metrics.get())
|
||||
|
||||
return Condensation(
|
||||
action=CondensationAction(
|
||||
forgotten_events_start_id=min(event.id for event in forgotten_events),
|
||||
forgotten_events_end_id=max(event.id for event in forgotten_events),
|
||||
summary=str(summary),
|
||||
summary_offset=self.keep_first,
|
||||
)
|
||||
)
|
||||
|
||||
def should_condense(self, view: View) -> bool:
|
||||
return len(view) > self.max_size or view.unhandled_condensation_request
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry
|
||||
) -> StructuredSummaryCondenser:
|
||||
# This condenser cannot take advantage of prompt caching. If it happens
|
||||
# to be set, we'll pay for the cache writes but never get a chance to
|
||||
# save on a read.
|
||||
llm_config = config.llm_config.model_copy()
|
||||
llm_config.caching_prompt = False
|
||||
llm = llm_registry.get_llm('condenser', llm_config)
|
||||
|
||||
return StructuredSummaryCondenser(
|
||||
llm=llm,
|
||||
max_size=config.max_size,
|
||||
keep_first=config.keep_first,
|
||||
max_event_length=config.max_event_length,
|
||||
)
|
||||
|
||||
|
||||
StructuredSummaryCondenser.register_config(StructuredSummaryCondenserConfig)
|
||||
@@ -1,898 +0,0 @@
|
||||
from typing import Generator
|
||||
|
||||
from litellm import ModelResponse
|
||||
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentThinkAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
TaskTrackingAction,
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
AgentThinkObservation,
|
||||
BrowserOutputObservation,
|
||||
CmdOutputObservation,
|
||||
FileDownloadObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
IPythonRunCellObservation,
|
||||
LoopDetectionObservation,
|
||||
TaskTrackingObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import (
|
||||
ConversationInstructions,
|
||||
PromptManager,
|
||||
RepositoryInfo,
|
||||
RuntimeInfo,
|
||||
)
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
"""Processes event history into a coherent conversation for the agent."""
|
||||
|
||||
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
|
||||
self.agent_config = config
|
||||
self.prompt_manager = prompt_manager
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_image_url(url: str | None) -> bool:
|
||||
"""Check if an image URL is valid and non-empty.
|
||||
|
||||
Args:
|
||||
url: The image URL to validate
|
||||
|
||||
Returns:
|
||||
True if the URL is valid, False otherwise
|
||||
"""
|
||||
return bool(url and url.strip())
|
||||
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_user_action: MessageAction,
|
||||
forgotten_event_ids: set[int] | None = None,
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Process state history into a list of messages for the LLM.
|
||||
|
||||
Ensures that tool call actions are processed correctly in function calling mode.
|
||||
|
||||
Args:
|
||||
condensed_history: The condensed history of events to convert
|
||||
initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly.
|
||||
forgotten_event_ids: Set of event IDs that have been forgotten/condensed. If the initial user action's ID
|
||||
is in this set, it will not be re-inserted to prevent re-execution of old instructions.
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
"""
|
||||
events = condensed_history
|
||||
# Default to empty set if not provided
|
||||
if forgotten_event_ids is None:
|
||||
forgotten_event_ids = set()
|
||||
|
||||
# Ensure the event list starts with SystemMessageAction, then MessageAction(source='user')
|
||||
self._ensure_system_message(events)
|
||||
self._ensure_initial_user_message(
|
||||
events, initial_user_action, forgotten_event_ids
|
||||
)
|
||||
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
|
||||
# Initialize empty messages list
|
||||
messages = []
|
||||
|
||||
# Process regular events
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
|
||||
for i, event in enumerate(events):
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
messages_to_add = self._process_action(
|
||||
action=event,
|
||||
pending_tool_call_action_messages=pending_tool_call_action_messages,
|
||||
vision_is_active=vision_is_active,
|
||||
)
|
||||
elif isinstance(event, Observation):
|
||||
messages_to_add = self._process_observation(
|
||||
obs=event,
|
||||
tool_call_id_to_message=tool_call_id_to_message,
|
||||
max_message_chars=max_message_chars,
|
||||
vision_is_active=vision_is_active,
|
||||
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
|
||||
current_index=i,
|
||||
events=events,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
|
||||
# Check pending tool call action messages and see if they are complete
|
||||
_response_ids_to_remove = []
|
||||
for (
|
||||
response_id,
|
||||
pending_message,
|
||||
) in pending_tool_call_action_messages.items():
|
||||
assert pending_message.tool_calls is not None, (
|
||||
'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. '
|
||||
f'Pending message: {pending_message}'
|
||||
)
|
||||
if all(
|
||||
tool_call.id in tool_call_id_to_message
|
||||
for tool_call in pending_message.tool_calls
|
||||
):
|
||||
# If complete:
|
||||
# -- 1. Add the message that **initiated** the tool calls
|
||||
messages_to_add.append(pending_message)
|
||||
# -- 2. Add the tool calls **results***
|
||||
for tool_call in pending_message.tool_calls:
|
||||
messages_to_add.append(tool_call_id_to_message[tool_call.id])
|
||||
tool_call_id_to_message.pop(tool_call.id)
|
||||
_response_ids_to_remove.append(response_id)
|
||||
# Cleanup the processed pending tool messages
|
||||
for response_id in _response_ids_to_remove:
|
||||
pending_tool_call_action_messages.pop(response_id)
|
||||
|
||||
messages += messages_to_add
|
||||
|
||||
# Apply final filtering so that the messages in context don't have unmatched tool calls
|
||||
# and tool responses, for example
|
||||
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
|
||||
|
||||
# Apply final formatting
|
||||
messages = self._apply_user_message_formatting(messages)
|
||||
|
||||
return messages
|
||||
|
||||
def _apply_user_message_formatting(self, messages: list[Message]) -> list[Message]:
|
||||
"""Applies formatting rules, such as adding newlines between consecutive user messages."""
|
||||
formatted_messages = []
|
||||
prev_role = None
|
||||
for msg in messages:
|
||||
# Add double newline between consecutive user messages
|
||||
if msg.role == 'user' and prev_role == 'user' and len(msg.content) > 0:
|
||||
# Find the first TextContent in the message to add newlines
|
||||
for content_item in msg.content:
|
||||
if isinstance(content_item, TextContent):
|
||||
# Prepend two newlines to ensure visual separation
|
||||
content_item.text = '\n\n' + content_item.text
|
||||
break
|
||||
formatted_messages.append(msg)
|
||||
prev_role = msg.role # Update prev_role after processing each message
|
||||
return formatted_messages
|
||||
|
||||
def _process_action(
|
||||
self,
|
||||
action: Action,
|
||||
pending_tool_call_action_messages: dict[str, Message],
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Converts an action into a message format that can be sent to the LLM.
|
||||
|
||||
This method handles different types of actions and formats them appropriately:
|
||||
1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish:
|
||||
- In function calling mode: Stores the LLM's response in pending_tool_call_action_messages
|
||||
- In non-function calling mode: Creates a message with the action string
|
||||
2. For MessageActions: Creates a message with the text content and optional image content
|
||||
|
||||
Args:
|
||||
action: The action to convert. Can be one of:
|
||||
- CmdRunAction: For executing bash commands
|
||||
- IPythonRunCellAction: For running IPython code
|
||||
- FileEditAction: For editing files
|
||||
- FileReadAction: For reading files using openhands-aci commands
|
||||
- BrowseInteractiveAction: For browsing the web
|
||||
- AgentFinishAction: For ending the interaction
|
||||
- MessageAction: For sending messages
|
||||
- MCPAction: For interacting with the MCP server
|
||||
pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages.
|
||||
Used in function calling mode to track tool calls that are waiting for their results.
|
||||
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the action.
|
||||
May be empty if the action is handled as a tool call in function calling mode.
|
||||
|
||||
Note:
|
||||
In function calling mode, tool-based actions are stored in pending_tool_call_action_messages
|
||||
rather than being returned immediately. They will be processed later when all corresponding
|
||||
tool call results are available.
|
||||
"""
|
||||
# create a regular message from an event
|
||||
if isinstance(
|
||||
action,
|
||||
(
|
||||
AgentDelegateAction,
|
||||
AgentThinkAction,
|
||||
IPythonRunCellAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
MCPAction,
|
||||
TaskTrackingAction,
|
||||
),
|
||||
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
|
||||
tool_metadata = action.tool_call_metadata
|
||||
|
||||
# Allow user actions to skip tool metadata validation
|
||||
if action.source == 'user' and tool_metadata is None:
|
||||
# For user-initiated actions without tool metadata, create a simple message
|
||||
return [
|
||||
Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(
|
||||
text=f'User requested to read file: {str(action)}'
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
assert tool_metadata is not None, (
|
||||
'Tool call metadata should NOT be None when function calling is enabled for agent actions. Action: '
|
||||
+ str(action)
|
||||
)
|
||||
|
||||
llm_response: ModelResponse = tool_metadata.model_response
|
||||
assistant_msg = getattr(llm_response.choices[0], 'message')
|
||||
|
||||
# Add the LLM message (assistant) that initiated the tool calls
|
||||
# (overwrites any previous message with the same response_id)
|
||||
pending_tool_call_action_messages[llm_response.id] = Message(
|
||||
role=getattr(assistant_msg, 'role', 'assistant'),
|
||||
# tool call content SHOULD BE a string
|
||||
content=[TextContent(text=assistant_msg.content)]
|
||||
if assistant_msg.content and assistant_msg.content.strip()
|
||||
else [],
|
||||
tool_calls=assistant_msg.tool_calls,
|
||||
)
|
||||
return []
|
||||
elif isinstance(action, AgentFinishAction):
|
||||
role = 'user' if action.source == 'user' else 'assistant'
|
||||
|
||||
# when agent finishes, it has tool_metadata
|
||||
# which has already been executed, and it doesn't have a response
|
||||
# when the user finishes (/exit), we don't have tool_metadata
|
||||
tool_metadata = action.tool_call_metadata
|
||||
if tool_metadata is not None:
|
||||
# take the response message from the tool call
|
||||
assistant_msg = getattr(
|
||||
tool_metadata.model_response.choices[0], 'message'
|
||||
)
|
||||
content = assistant_msg.content or ''
|
||||
|
||||
# save content if any, to thought
|
||||
if action.thought:
|
||||
if action.thought != content:
|
||||
action.thought += '\n' + content
|
||||
else:
|
||||
action.thought = content
|
||||
|
||||
# remove the tool call metadata
|
||||
action.tool_call_metadata = None
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=[TextContent(text=action.thought)],
|
||||
)
|
||||
]
|
||||
elif isinstance(action, MessageAction):
|
||||
role = 'user' if action.source == 'user' else 'assistant'
|
||||
content = [TextContent(text=action.content or '')]
|
||||
if action.image_urls:
|
||||
if role == 'user':
|
||||
for idx, url in enumerate(action.image_urls):
|
||||
# Only add descriptive text if vision is active
|
||||
if vision_is_active:
|
||||
content.append(TextContent(text=f'Image {idx + 1}:'))
|
||||
content.append(ImageContent(image_urls=[url]))
|
||||
else:
|
||||
content.append(ImageContent(image_urls=action.image_urls))
|
||||
if role not in ('user', 'system', 'assistant', 'tool'):
|
||||
raise ValueError(f'Invalid role: {role}')
|
||||
return [
|
||||
Message(
|
||||
role=role, # type: ignore[arg-type]
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, CmdRunAction) and action.source == 'user':
|
||||
content = [
|
||||
TextContent(text=f'User executed the command:\n{action.command}')
|
||||
]
|
||||
return [
|
||||
Message(
|
||||
role='user', # Always user for CmdRunAction
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, SystemMessageAction):
|
||||
# Convert SystemMessageAction to a system message
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[TextContent(text=action.content)],
|
||||
# Include tools if function calling is enabled
|
||||
tool_calls=None,
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def _process_observation(
|
||||
self,
|
||||
obs: Observation,
|
||||
tool_call_id_to_message: dict[str, Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
current_index: int = 0,
|
||||
events: list[Event] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Converts an observation into a message format that can be sent to the LLM.
|
||||
|
||||
This method handles different types of observations and formats them appropriately:
|
||||
- CmdOutputObservation: Formats command execution results with exit codes
|
||||
- IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images
|
||||
- FileEditObservation: Formats file editing results
|
||||
- FileReadObservation: Formats file reading results from openhands-aci
|
||||
- AgentDelegateObservation: Formats results from delegated agent tasks
|
||||
- ErrorObservation: Formats error messages from failed actions
|
||||
- UserRejectObservation: Formats user rejection messages
|
||||
- FileDownloadObservation: Formats the result of a browsing action that opened/downloaded a file
|
||||
|
||||
In function calling mode, observations with tool_call_metadata are stored in
|
||||
tool_call_id_to_message for later processing instead of being returned immediately.
|
||||
|
||||
Args:
|
||||
obs: The observation to convert
|
||||
tool_call_id_to_message: Dictionary mapping tool call IDs to their corresponding messages (used in function calling mode)
|
||||
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
||||
current_index: The index of the current event in the events list (for deduplication)
|
||||
events: The list of all events (for deduplication)
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the observation.
|
||||
May be empty if the observation is handled as a tool response in function calling mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If the observation type is unknown
|
||||
"""
|
||||
message: Message
|
||||
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
# Note: CmdOutputObservation content is already truncated at initialization,
|
||||
# and the observation content should not have been modified after initialization
|
||||
# we keep this truncation for backwards compatibility for a time
|
||||
if obs.tool_call_metadata is None:
|
||||
# if it doesn't have tool call metadata, it was triggered by a user action
|
||||
text = truncate_content(
|
||||
f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}',
|
||||
max_message_chars,
|
||||
)
|
||||
else:
|
||||
text = truncate_content(obs.to_agent_observation(), max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, MCPObservation):
|
||||
# logger.warning(f'MCPObservation: {obs}')
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, IPythonRunCellObservation):
|
||||
text = obs.content
|
||||
# Clean up any remaining base64 images in text content
|
||||
splitted = text.split('\n')
|
||||
for i, line in enumerate(splitted):
|
||||
if ' already displayed to user'
|
||||
)
|
||||
text = '\n'.join(splitted)
|
||||
text = truncate_content(text, max_message_chars)
|
||||
|
||||
# Create message content with text
|
||||
content: list[TextContent | ImageContent] = [TextContent(text=text)]
|
||||
|
||||
# Add image URLs if available
|
||||
if obs.image_urls:
|
||||
# Filter out empty or invalid image URLs
|
||||
valid_image_urls = [
|
||||
url for url in obs.image_urls if self._is_valid_image_url(url)
|
||||
]
|
||||
invalid_count = len(obs.image_urls) - len(valid_image_urls)
|
||||
|
||||
if valid_image_urls:
|
||||
content.append(ImageContent(image_urls=valid_image_urls))
|
||||
# Only add explanatory text if vision is active
|
||||
if vision_is_active and invalid_count > 0:
|
||||
# Add text indicating some images were filtered
|
||||
content[
|
||||
0
|
||||
].text += f'\n\nNote: {invalid_count} invalid or empty image(s) were filtered from this output. The agent may need to use alternative methods to access visual information.' # type: ignore[union-attr]
|
||||
else:
|
||||
logger.debug(
|
||||
'IPython observation has image URLs but none are valid'
|
||||
)
|
||||
# Only add explanatory text if vision is active
|
||||
if vision_is_active:
|
||||
# Add text indicating all images were filtered
|
||||
content[
|
||||
0
|
||||
].text += f'\n\nNote: All {len(obs.image_urls)} image(s) in this output were invalid or empty and have been filtered. The agent should use alternative methods to access visual information.' # type: ignore[union-attr]
|
||||
|
||||
message = Message(role='user', content=content)
|
||||
elif isinstance(obs, FileEditObservation):
|
||||
text = truncate_content(str(obs), max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, FileReadObservation):
|
||||
message = Message(
|
||||
role='user', content=[TextContent(text=obs.content)]
|
||||
) # Content is already truncated by openhands-aci
|
||||
elif isinstance(obs, BrowserOutputObservation):
|
||||
text = obs.content
|
||||
content = [TextContent(text=text)]
|
||||
if (
|
||||
obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE
|
||||
and enable_som_visual_browsing
|
||||
):
|
||||
# Only add descriptive text if vision is active
|
||||
if vision_is_active:
|
||||
# We know content[0] is TextContent since we just created it above
|
||||
text_content = content[0]
|
||||
assert isinstance(text_content, TextContent)
|
||||
text_content.text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. However, the Accessibility tree contains information from the entire webpage.)\n'
|
||||
|
||||
# Determine which image to use and validate it
|
||||
image_url = None
|
||||
image_type = None
|
||||
if obs.set_of_marks is not None and len(obs.set_of_marks) > 0:
|
||||
image_url = obs.set_of_marks
|
||||
image_type = 'set of marks'
|
||||
elif obs.screenshot is not None and len(obs.screenshot) > 0:
|
||||
image_url = obs.screenshot
|
||||
image_type = 'screenshot'
|
||||
|
||||
# Always add ImageContent if we have a valid image URL
|
||||
if self._is_valid_image_url(image_url):
|
||||
content.append(ImageContent(image_urls=[image_url])) # type: ignore[list-item]
|
||||
logger.debug(f'Adding {image_type} for browsing')
|
||||
else:
|
||||
if vision_is_active and image_url:
|
||||
logger.warning(
|
||||
f'Invalid image URL format for {image_type}: {image_url[:50]}...'
|
||||
)
|
||||
# Add text indicating the image was filtered (only if vision is active)
|
||||
content[
|
||||
0
|
||||
].text += f'\n\nNote: The {image_type} for this webpage was invalid or empty and has been filtered. The agent should use alternative methods to access visual information about the webpage.' # type: ignore[union-attr]
|
||||
elif vision_is_active and not image_url:
|
||||
logger.debug(
|
||||
'Vision enabled for browsing, but no valid image available'
|
||||
)
|
||||
# Add text indicating no image was available (only if vision is active)
|
||||
content[
|
||||
0
|
||||
].text += '\n\nNote: No visual information (screenshot or set of marks) is available for this webpage. The agent should rely on the text content above.' # type: ignore[union-attr]
|
||||
|
||||
message = Message(role='user', content=content)
|
||||
elif isinstance(obs, AgentDelegateObservation):
|
||||
text = truncate_content(
|
||||
obs.outputs.get('content', obs.content),
|
||||
max_message_chars,
|
||||
)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, AgentThinkObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, TaskTrackingObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
text += '\n[Error occurred in processing last action]'
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, UserRejectObservation):
|
||||
text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
|
||||
text += '\n[Last action has been rejected by the user]'
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, AgentCondensationObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, FileDownloadObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, LoopDetectionObservation):
|
||||
# LoopRecovery should not be observed by llm, handled internally.
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, RecallObservation)
|
||||
and self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
# everything is optional, check if they are present
|
||||
if obs.repo_name or obs.repo_directory:
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name=obs.repo_name or '',
|
||||
repo_directory=obs.repo_directory or '',
|
||||
branch_name=obs.repo_branch or None,
|
||||
)
|
||||
else:
|
||||
repo_info = None
|
||||
|
||||
date = obs.date
|
||||
|
||||
if obs.runtime_hosts or obs.additional_agent_instructions:
|
||||
runtime_info = RuntimeInfo(
|
||||
available_hosts=obs.runtime_hosts,
|
||||
additional_agent_instructions=obs.additional_agent_instructions,
|
||||
date=date,
|
||||
custom_secrets_descriptions=obs.custom_secrets_descriptions,
|
||||
working_dir=obs.working_dir,
|
||||
)
|
||||
else:
|
||||
runtime_info = RuntimeInfo(
|
||||
date=date,
|
||||
custom_secrets_descriptions=obs.custom_secrets_descriptions,
|
||||
working_dir=obs.working_dir,
|
||||
)
|
||||
|
||||
conversation_instructions = None
|
||||
|
||||
if obs.conversation_instructions:
|
||||
conversation_instructions = ConversationInstructions(
|
||||
content=obs.conversation_instructions
|
||||
)
|
||||
|
||||
repo_instructions = (
|
||||
obs.repo_instructions if obs.repo_instructions else ''
|
||||
)
|
||||
|
||||
# Have some meaningful content before calling the template
|
||||
has_repo_info = repo_info is not None and (
|
||||
repo_info.repo_name or repo_info.repo_directory
|
||||
)
|
||||
has_runtime_info = runtime_info is not None and (
|
||||
runtime_info.date or runtime_info.custom_secrets_descriptions
|
||||
)
|
||||
has_repo_instructions = bool(repo_instructions.strip())
|
||||
has_conversation_instructions = conversation_instructions is not None
|
||||
|
||||
# Filter and process microagent knowledge
|
||||
filtered_agents = []
|
||||
if obs.microagent_knowledge:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in obs.microagent_knowledge
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
has_microagent_knowledge = bool(filtered_agents)
|
||||
|
||||
# Generate appropriate content based on what is present
|
||||
message_content: list[TextContent | ImageContent] = []
|
||||
|
||||
# Build the workspace context information
|
||||
if (
|
||||
has_repo_info
|
||||
or has_runtime_info
|
||||
or has_repo_instructions
|
||||
or has_conversation_instructions
|
||||
):
|
||||
formatted_workspace_text = (
|
||||
self.prompt_manager.build_workspace_context(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
conversation_instructions=conversation_instructions,
|
||||
repo_instructions=repo_instructions,
|
||||
)
|
||||
)
|
||||
message_content.append(TextContent(text=formatted_workspace_text))
|
||||
|
||||
# Add microagent knowledge if present
|
||||
if has_microagent_knowledge:
|
||||
formatted_microagent_text = (
|
||||
self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
)
|
||||
message_content.append(TextContent(text=formatted_microagent_text))
|
||||
|
||||
# Return the combined message if we have any content
|
||||
if message_content:
|
||||
message = Message(role='user', content=message_content)
|
||||
else:
|
||||
return []
|
||||
elif obs.recall_type == RecallType.KNOWLEDGE:
|
||||
# Use prompt manager to build the microagent info
|
||||
# First, filter out agents that appear in earlier RecallObservations
|
||||
filtered_agents = self._filter_agents_in_microagent_obs(
|
||||
obs, current_index, events or []
|
||||
)
|
||||
|
||||
# Create and return a message if there is microagent knowledge to include
|
||||
if filtered_agents:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in filtered_agents
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
# Only proceed if we still have agents after filtering out disabled ones
|
||||
if filtered_agents:
|
||||
formatted_text = self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
|
||||
return [
|
||||
Message(
|
||||
role='user', content=[TextContent(text=formatted_text)]
|
||||
)
|
||||
]
|
||||
|
||||
# Return empty list if no microagents to include or all were disabled
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, RecallObservation)
|
||||
and not self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
# If prompt extensions are disabled, we don't add any additional info
|
||||
# TODO: test this
|
||||
return []
|
||||
else:
|
||||
# If an observation message is not returned, it will cause an error
|
||||
# when the LLM tries to return the next message
|
||||
raise ValueError(f'Unknown observation type: {type(obs)}')
|
||||
|
||||
# Update the message as tool response properly
|
||||
if (tool_call_metadata := getattr(obs, 'tool_call_metadata', None)) is not None:
|
||||
tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
|
||||
role='tool',
|
||||
content=message.content,
|
||||
tool_call_id=tool_call_metadata.tool_call_id,
|
||||
name=tool_call_metadata.function_name,
|
||||
)
|
||||
# No need to return the observation message
|
||||
# because it will be added by get_action_message when all the corresponding
|
||||
# tool calls in the SAME request are processed
|
||||
return []
|
||||
|
||||
return [message]
|
||||
|
||||
def apply_prompt_caching(self, messages: list[Message]) -> None:
|
||||
"""Applies caching breakpoints to the messages.
|
||||
|
||||
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
|
||||
"""
|
||||
if len(messages) > 0 and messages[0].role == 'system':
|
||||
messages[0].content[-1].cache_prompt = True
|
||||
# NOTE: this is only needed for anthropic
|
||||
for message in reversed(messages):
|
||||
if message.role in ('user', 'tool'):
|
||||
message.content[
|
||||
-1
|
||||
].cache_prompt = True # Last item inside the message content
|
||||
break
|
||||
|
||||
def _filter_agents_in_microagent_obs(
|
||||
self, obs: RecallObservation, current_index: int, events: list[Event]
|
||||
) -> list[MicroagentKnowledge]:
|
||||
"""Filter out agents that appear in earlier RecallObservations.
|
||||
|
||||
Args:
|
||||
obs: The current RecallObservation to filter
|
||||
current_index: The index of the current event in the events list
|
||||
events: The list of all events
|
||||
|
||||
Returns:
|
||||
list[MicroagentKnowledge]: The filtered list of microagent knowledge
|
||||
"""
|
||||
if obs.recall_type != RecallType.KNOWLEDGE:
|
||||
return obs.microagent_knowledge
|
||||
|
||||
# For each agent in the current microagent observation, check if it appears in any earlier microagent observation
|
||||
filtered_agents = []
|
||||
for agent in obs.microagent_knowledge:
|
||||
# Keep this agent if it doesn't appear in any earlier observation
|
||||
# that is, if this is the first microagent observation with this microagent
|
||||
if not self._has_agent_in_earlier_events(agent.name, current_index, events):
|
||||
filtered_agents.append(agent)
|
||||
|
||||
return filtered_agents
|
||||
|
||||
def _has_agent_in_earlier_events(
|
||||
self, agent_name: str, current_index: int, events: list[Event]
|
||||
) -> bool:
|
||||
"""Check if an agent appears in any earlier RecallObservation in the event list.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent to look for
|
||||
current_index: The index of the current event in the events list
|
||||
events: The list of all events
|
||||
|
||||
Returns:
|
||||
bool: True if the agent appears in an earlier RecallObservation, False otherwise
|
||||
"""
|
||||
for event in events[:current_index]:
|
||||
# Note that this check includes the WORKSPACE_CONTEXT
|
||||
if isinstance(event, RecallObservation):
|
||||
if any(
|
||||
agent.name == agent_name for agent in event.microagent_knowledge
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _filter_unmatched_tool_calls(
|
||||
messages: list[Message],
|
||||
) -> Generator[Message, None, None]:
|
||||
"""Filter out tool calls that don't have matching tool responses and vice versa.
|
||||
|
||||
This ensures that every tool_call_id in a tool message has a corresponding tool_calls[].id
|
||||
in an assistant message, and vice versa. The original list is unmodified, when tool_calls is
|
||||
updated the message is copied.
|
||||
|
||||
This does not remove items with id set to None.
|
||||
"""
|
||||
tool_call_ids = {
|
||||
tool_call.id
|
||||
for message in messages
|
||||
if message.tool_calls
|
||||
for tool_call in message.tool_calls
|
||||
if message.role == 'assistant' and tool_call.id
|
||||
}
|
||||
tool_response_ids = {
|
||||
message.tool_call_id
|
||||
for message in messages
|
||||
if message.role == 'tool' and message.tool_call_id
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
# Remove tool messages with no matching assistant tool call
|
||||
if message.role == 'tool' and message.tool_call_id:
|
||||
if message.tool_call_id in tool_call_ids:
|
||||
yield message
|
||||
|
||||
# Remove assistant tool calls with no matching tool response
|
||||
elif message.role == 'assistant' and message.tool_calls:
|
||||
all_tool_calls_match = all(
|
||||
tool_call.id in tool_response_ids
|
||||
for tool_call in message.tool_calls
|
||||
)
|
||||
if all_tool_calls_match:
|
||||
yield message
|
||||
else:
|
||||
matched_tool_calls = [
|
||||
tool_call
|
||||
for tool_call in message.tool_calls
|
||||
if tool_call.id in tool_response_ids
|
||||
]
|
||||
|
||||
if matched_tool_calls:
|
||||
# Keep an updated message if there are tools calls left
|
||||
yield message.model_copy(
|
||||
update={'tool_calls': matched_tool_calls}
|
||||
)
|
||||
else:
|
||||
# Any other case is kept
|
||||
yield message
|
||||
|
||||
def _ensure_system_message(self, events: list[Event]) -> None:
|
||||
"""Checks if a SystemMessageAction exists and adds one if not (for legacy compatibility)."""
|
||||
# Check if there's a SystemMessageAction in the events
|
||||
has_system_message = any(
|
||||
isinstance(event, SystemMessageAction) for event in events
|
||||
)
|
||||
|
||||
# Legacy behavior: If no SystemMessageAction is found, add one
|
||||
if not has_system_message:
|
||||
logger.debug(
|
||||
'[ConversationMemory] No SystemMessageAction found in events. '
|
||||
'Adding one for backward compatibility. '
|
||||
)
|
||||
system_prompt = self.prompt_manager.get_system_message(
|
||||
cli_mode=self.agent_config.cli_mode
|
||||
)
|
||||
if system_prompt:
|
||||
system_message = SystemMessageAction(content=system_prompt)
|
||||
# Insert the system message directly at the beginning of the events list
|
||||
events.insert(0, system_message)
|
||||
logger.info(
|
||||
'[ConversationMemory] Added SystemMessageAction for backward compatibility'
|
||||
)
|
||||
|
||||
def _ensure_initial_user_message(
|
||||
self,
|
||||
events: list[Event],
|
||||
initial_user_action: MessageAction,
|
||||
forgotten_event_ids: set[int],
|
||||
) -> None:
|
||||
"""Checks if the second event is a user MessageAction and inserts the provided one if needed.
|
||||
|
||||
IMPORTANT: If the initial user action has been condensed (its ID is in forgotten_event_ids),
|
||||
we do NOT re-insert it. This prevents old instructions from being re-executed after
|
||||
conversation condensation. The condensation summary already contains the context of
|
||||
what was requested and completed.
|
||||
|
||||
Args:
|
||||
events: The list of events to modify in-place
|
||||
initial_user_action: The initial user message action from the full history
|
||||
forgotten_event_ids: Set of event IDs that have been forgotten/condensed
|
||||
"""
|
||||
if (
|
||||
not events
|
||||
): # Should have system message from previous step, but safety check
|
||||
logger.error('Cannot ensure initial user message: event list is empty.')
|
||||
# Or raise? Let's log for now, _ensure_system_message should handle this.
|
||||
return
|
||||
|
||||
# Check if the initial user action has been condensed/forgotten.
|
||||
# If so, we should NOT re-insert it to prevent re-execution of old instructions.
|
||||
# The condensation summary already contains the context of what was requested.
|
||||
initial_user_action_id = initial_user_action.id
|
||||
if initial_user_action_id in forgotten_event_ids:
|
||||
logger.info(
|
||||
f'Initial user action (id={initial_user_action_id}) has been condensed. '
|
||||
'Not re-inserting to prevent re-execution of old instructions.'
|
||||
)
|
||||
return
|
||||
|
||||
# We expect events[0] to be SystemMessageAction after _ensure_system_message
|
||||
if len(events) == 1:
|
||||
# Only system message exists
|
||||
logger.info(
|
||||
'Initial user message action was missing. Inserting the initial user message.'
|
||||
)
|
||||
events.insert(1, initial_user_action)
|
||||
elif not isinstance(events[1], MessageAction) or events[1].source != 'user':
|
||||
# The second event exists but is not the correct initial user message action.
|
||||
# We will insert the correct one provided.
|
||||
logger.info(
|
||||
'Second event was not the initial user message action. Inserting correct one at index 1.'
|
||||
)
|
||||
|
||||
# Insert the user message event at index 1. This will be the second message as LLM APIs expect
|
||||
# but something was wrong with the history, so log all we can.
|
||||
events.insert(1, initial_user_action)
|
||||
|
||||
# Else: events[1] is already a user MessageAction.
|
||||
# Check if it matches the one provided (if any discrepancy, log warning but proceed).
|
||||
elif events[1] != initial_user_action:
|
||||
logger.debug(
|
||||
'The user MessageAction at index 1 does not match the provided initial_user_action. '
|
||||
'Proceeding with the one found in condensed history.'
|
||||
)
|
||||
@@ -1,405 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import openhands
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
from openhands.microagent import (
|
||||
BaseMicroagent,
|
||||
KnowledgeMicroagent,
|
||||
RepoMicroagent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.utils.prompt import (
|
||||
ConversationInstructions,
|
||||
RepositoryInfo,
|
||||
RuntimeInfo,
|
||||
)
|
||||
|
||||
GLOBAL_MICROAGENTS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'skills',
|
||||
)
|
||||
|
||||
USER_MICROAGENTS_DIR = Path.home() / '.openhands' / 'microagents'
|
||||
|
||||
|
||||
class Memory:
|
||||
"""Memory is a component that listens to the EventStream for information retrieval actions
|
||||
(a RecallAction) and publishes observations with the content (such as RecallObservation).
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
status_callback: Callable | None
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
repo_microagents: dict[str, RepoMicroagent]
|
||||
knowledge_microagents: dict[str, KnowledgeMicroagent]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
sid: str,
|
||||
status_callback: Callable | None = None,
|
||||
):
|
||||
self.event_stream = event_stream
|
||||
self.sid = sid if sid else str(uuid.uuid4())
|
||||
self.status_callback = status_callback
|
||||
self.loop = None
|
||||
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY,
|
||||
self.on_event,
|
||||
self.sid,
|
||||
)
|
||||
|
||||
# Additional placeholders to store user workspace microagents
|
||||
self.repo_microagents = {}
|
||||
self.knowledge_microagents = {}
|
||||
|
||||
# Store repository / runtime info to send them to the templating later
|
||||
self.repository_info: RepositoryInfo | None = None
|
||||
self.runtime_info: RuntimeInfo | None = None
|
||||
self.conversation_instructions: ConversationInstructions | None = None
|
||||
|
||||
# Load global microagents (Knowledge + Repo)
|
||||
# from typically OpenHands/skills (i.e., the PUBLIC microagents)
|
||||
self._load_global_microagents()
|
||||
|
||||
# Load user microagents from ~/.openhands/microagents/
|
||||
self._load_user_microagents()
|
||||
|
||||
def on_event(self, event: Event):
|
||||
"""Handle an event from the event stream."""
|
||||
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
"""Handle an event from the event stream asynchronously."""
|
||||
try:
|
||||
if isinstance(event, RecallAction):
|
||||
# if this is a workspace context recall (on first user message)
|
||||
# create and add a RecallObservation
|
||||
# with info about repo, runtime, instructions, etc. including microagent knowledge if any
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and event.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
):
|
||||
logger.debug('Workspace context recall')
|
||||
workspace_obs: RecallObservation | NullObservation | None = None
|
||||
|
||||
workspace_obs = self._on_workspace_context_recall(event)
|
||||
if workspace_obs is None:
|
||||
workspace_obs = NullObservation(content='')
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
workspace_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
self.event_stream.add_event(workspace_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
|
||||
# Handle knowledge recall (triggered microagents)
|
||||
# Allow triggering from both user and agent messages
|
||||
elif (
|
||||
event.source == EventSource.USER
|
||||
or event.source == EventSource.AGENT
|
||||
) and event.recall_type == RecallType.KNOWLEDGE:
|
||||
logger.debug(
|
||||
f'Microagent knowledge recall from {event.source} message'
|
||||
)
|
||||
microagent_obs: RecallObservation | NullObservation | None = None
|
||||
microagent_obs = self._on_microagent_recall(event)
|
||||
if microagent_obs is None:
|
||||
microagent_obs = NullObservation(content='')
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
microagent_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
self.event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
except Exception as e:
|
||||
error_str = f'Error: {str(e.__class__.__name__)}'
|
||||
logger.error(error_str)
|
||||
self.set_runtime_status(RuntimeStatus.ERROR_MEMORY, error_str)
|
||||
return
|
||||
|
||||
def _on_workspace_context_recall(
|
||||
self, event: RecallAction
|
||||
) -> RecallObservation | None:
|
||||
"""Add repository and runtime information to the stream as a RecallObservation.
|
||||
|
||||
This method collects information from all available repo microagents and concatenates their contents.
|
||||
Multiple repo microagents are supported, and their contents will be concatenated with newlines between them.
|
||||
"""
|
||||
# Create WORKSPACE_CONTEXT info:
|
||||
# - repository_info
|
||||
# - runtime_info
|
||||
# - repository_instructions
|
||||
# - microagent_knowledge
|
||||
|
||||
# Collect raw repository instructions
|
||||
repo_instructions = ''
|
||||
|
||||
# Retrieve the context of repo instructions from all repo microagents
|
||||
for microagent in self.repo_microagents.values():
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Create observation if we have anything
|
||||
if (
|
||||
self.repository_info
|
||||
or self.runtime_info
|
||||
or repo_instructions
|
||||
or microagent_knowledge
|
||||
or self.conversation_instructions
|
||||
):
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name=(
|
||||
self.repository_info.repo_name
|
||||
if self.repository_info
|
||||
and self.repository_info.repo_name is not None
|
||||
else ''
|
||||
),
|
||||
repo_directory=(
|
||||
self.repository_info.repo_directory
|
||||
if self.repository_info
|
||||
and self.repository_info.repo_directory is not None
|
||||
else ''
|
||||
),
|
||||
repo_branch=(
|
||||
self.repository_info.branch_name
|
||||
if self.repository_info
|
||||
and self.repository_info.branch_name is not None
|
||||
else ''
|
||||
),
|
||||
repo_instructions=repo_instructions if repo_instructions else '',
|
||||
runtime_hosts=(
|
||||
self.runtime_info.available_hosts
|
||||
if self.runtime_info
|
||||
and self.runtime_info.available_hosts is not None
|
||||
else {}
|
||||
),
|
||||
additional_agent_instructions=(
|
||||
self.runtime_info.additional_agent_instructions
|
||||
if self.runtime_info
|
||||
and self.runtime_info.additional_agent_instructions is not None
|
||||
else ''
|
||||
),
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Added workspace context',
|
||||
date=self.runtime_info.date if self.runtime_info is not None else '',
|
||||
custom_secrets_descriptions=(
|
||||
self.runtime_info.custom_secrets_descriptions
|
||||
if self.runtime_info is not None
|
||||
else {}
|
||||
),
|
||||
conversation_instructions=(
|
||||
self.conversation_instructions.content
|
||||
if self.conversation_instructions is not None
|
||||
else ''
|
||||
),
|
||||
working_dir=self.runtime_info.working_dir if self.runtime_info else '',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _on_microagent_recall(
|
||||
self,
|
||||
event: RecallAction,
|
||||
) -> RecallObservation | None:
|
||||
"""When a microagent action triggers microagents, create a RecallObservation with structured data."""
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Create observation if we have anything
|
||||
if microagent_knowledge:
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
|
||||
"""Find microagent knowledge based on a query.
|
||||
|
||||
Args:
|
||||
query: The query to search for microagent triggers
|
||||
|
||||
Returns:
|
||||
A list of MicroagentKnowledge objects for matched triggers
|
||||
"""
|
||||
recalled_content: list[MicroagentKnowledge] = []
|
||||
|
||||
# skip empty queries
|
||||
if not query:
|
||||
return recalled_content
|
||||
|
||||
# Search for microagent triggers in the query
|
||||
for name, microagent in self.knowledge_microagents.items():
|
||||
trigger = microagent.match_trigger(query)
|
||||
if trigger:
|
||||
logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
|
||||
recalled_content.append(
|
||||
MicroagentKnowledge(
|
||||
name=microagent.name,
|
||||
trigger=trigger,
|
||||
content=microagent.content,
|
||||
)
|
||||
)
|
||||
return recalled_content
|
||||
|
||||
def load_user_workspace_microagents(
|
||||
self, user_microagents: list[BaseMicroagent]
|
||||
) -> None:
|
||||
"""This method loads microagents from a user's cloned repo or workspace directory.
|
||||
|
||||
This is typically called from agent_session or setup once the workspace is cloned.
|
||||
"""
|
||||
logger.info(
|
||||
'Loading user workspace microagents: %s', [m.name for m in user_microagents]
|
||||
)
|
||||
for user_microagent in user_microagents:
|
||||
if isinstance(user_microagent, KnowledgeMicroagent):
|
||||
self.knowledge_microagents[user_microagent.name] = user_microagent
|
||||
elif isinstance(user_microagent, RepoMicroagent):
|
||||
self.repo_microagents[user_microagent.name] = user_microagent
|
||||
|
||||
def _load_global_microagents(self) -> None:
|
||||
"""Loads microagents from the global microagents_dir"""
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(
|
||||
GLOBAL_MICROAGENTS_DIR
|
||||
)
|
||||
for name, agent_knowledge in knowledge_agents.items():
|
||||
self.knowledge_microagents[name] = agent_knowledge
|
||||
for name, agent_repo in repo_agents.items():
|
||||
self.repo_microagents[name] = agent_repo
|
||||
|
||||
def _load_user_microagents(self) -> None:
|
||||
"""Loads microagents from the user's home directory (~/.openhands/microagents/)
|
||||
Creates the directory if it doesn't exist.
|
||||
"""
|
||||
try:
|
||||
# Create the user microagents directory if it doesn't exist
|
||||
os.makedirs(USER_MICROAGENTS_DIR, exist_ok=True)
|
||||
|
||||
# Load microagents from user directory
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(
|
||||
USER_MICROAGENTS_DIR
|
||||
)
|
||||
|
||||
for name, agent_knowledge in knowledge_agents.items():
|
||||
self.knowledge_microagents[name] = agent_knowledge
|
||||
for name, agent_repo in repo_agents.items():
|
||||
self.repo_microagents[name] = agent_repo
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to load user microagents from {USER_MICROAGENTS_DIR}: {str(e)}'
|
||||
)
|
||||
|
||||
def get_microagent_mcp_tools(self) -> list[MCPConfig]:
|
||||
"""Get MCP tools from all repo microagents (always active)
|
||||
|
||||
Returns:
|
||||
A list of MCP tools configurations from microagents
|
||||
"""
|
||||
mcp_configs: list[MCPConfig] = []
|
||||
|
||||
# Check all repo microagents for MCP tools (always active)
|
||||
for agent in self.repo_microagents.values():
|
||||
if agent.metadata.mcp_tools:
|
||||
mcp_configs.append(agent.metadata.mcp_tools)
|
||||
logger.debug(
|
||||
f'Found MCP tools in repo microagent {agent.name}: {agent.metadata.mcp_tools}'
|
||||
)
|
||||
|
||||
return mcp_configs
|
||||
|
||||
def set_repository_info(
|
||||
self, repo_name: str, repo_directory: str, branch_name: str | None = None
|
||||
) -> None:
|
||||
"""Store repository info so we can reference it in an observation."""
|
||||
if repo_name or repo_directory:
|
||||
self.repository_info = RepositoryInfo(
|
||||
repo_name, repo_directory, branch_name
|
||||
)
|
||||
else:
|
||||
self.repository_info = None
|
||||
|
||||
def set_runtime_info(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
custom_secrets_descriptions: dict[str, str],
|
||||
working_dir: str,
|
||||
) -> None:
|
||||
"""Store runtime info (web hosts, ports, etc.)."""
|
||||
# e.g. { '127.0.0.1': 8080 }
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
date = str(utc_now.date())
|
||||
|
||||
if runtime.web_hosts or runtime.additional_agent_instructions:
|
||||
self.runtime_info = RuntimeInfo(
|
||||
available_hosts=runtime.web_hosts,
|
||||
additional_agent_instructions=runtime.additional_agent_instructions,
|
||||
date=date,
|
||||
custom_secrets_descriptions=custom_secrets_descriptions,
|
||||
working_dir=working_dir,
|
||||
)
|
||||
else:
|
||||
self.runtime_info = RuntimeInfo(
|
||||
date=date,
|
||||
custom_secrets_descriptions=custom_secrets_descriptions,
|
||||
working_dir=working_dir,
|
||||
)
|
||||
|
||||
def set_conversation_instructions(
|
||||
self, conversation_instructions: str | None
|
||||
) -> None:
|
||||
"""Set contextual information for conversation
|
||||
This is information the agent may require
|
||||
"""
|
||||
self.conversation_instructions = ConversationInstructions(
|
||||
content=conversation_instructions or ''
|
||||
)
|
||||
|
||||
def set_runtime_status(self, status: RuntimeStatus, message: str):
|
||||
"""Sends an error message if the callback function was provided."""
|
||||
if self.status_callback:
|
||||
try:
|
||||
if self.loop is None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._set_runtime_status('error', status, message), self.loop
|
||||
)
|
||||
except (RuntimeError, KeyError) as e:
|
||||
logger.error(
|
||||
f'Error sending status message: {e.__class__.__name__}',
|
||||
stack_info=False,
|
||||
)
|
||||
|
||||
async def _set_runtime_status(
|
||||
self, msg_type: str, runtime_status: RuntimeStatus, message: str
|
||||
):
|
||||
"""Sends a status message to the client."""
|
||||
if self.status_callback:
|
||||
self.status_callback(msg_type, runtime_status, message)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user