mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
search_eng
...
debug-visu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51d1a993fb |
6
.github/workflows/deploy-docs.yml
vendored
6
.github/workflows/deploy-docs.yml
vendored
@@ -11,7 +11,6 @@ on:
|
||||
paths:
|
||||
- 'docs/**'
|
||||
- '.github/workflows/deploy-docs.yml'
|
||||
- 'pydoc-markdown.yml'
|
||||
branches:
|
||||
- main
|
||||
|
||||
@@ -40,10 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Generate Python Docs
|
||||
run: |
|
||||
rm -rf docs/modules/python
|
||||
pip install pydoc-markdown
|
||||
pydoc-markdown
|
||||
run: rm -rf docs/modules/python && pip install pydoc-markdown && pydoc-markdown
|
||||
- name: Install dependencies
|
||||
run: cd docs && npm ci
|
||||
- name: Build website
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
@@ -96,7 +96,7 @@ troubleshooting resources, and advanced configuration options.
|
||||
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
|
||||
|
||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
||||
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ Explorez le code source d'OpenHands sur [GitHub](https://github.com/All-Hands-AI
|
||||
/>
|
||||
</a>
|
||||
<br></br>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
||||
<img
|
||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||
alt="Join our Slack community"
|
||||
|
||||
@@ -42,7 +42,7 @@ OpenHands 是一个**自主 AI 软件工程师**,能够执行复杂的工程
|
||||
/>
|
||||
</a>
|
||||
<br></br>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
||||
<img
|
||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||
alt="Join our Slack community"
|
||||
|
||||
@@ -308,11 +308,6 @@ The agent configuration options are defined in the `[agent]` and `[agent.<agent_
|
||||
- Default: `false`
|
||||
- Description: Whether Jupyter is enabled in the action space
|
||||
|
||||
- `enable_search_engine`
|
||||
- Type: `bool`
|
||||
- Default: `false`
|
||||
- Description: Whether the search engine tool is enabled in the action space. See [Search Configuration](./search/search-configuration.md) for details.
|
||||
|
||||
- `enable_history_truncation`
|
||||
- Type: `bool`
|
||||
- Default: `true`
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
# Search Configuration
|
||||
|
||||
OpenHands provides a search engine capability that allows agents to perform web searches using the Brave Search API. This guide explains how to configure and use the search feature.
|
||||
|
||||
## Overview
|
||||
|
||||
The search engine feature enables agents to:
|
||||
- Execute web search queries programmatically
|
||||
- Get structured results including web pages, news, videos, and FAQs
|
||||
- Avoid CAPTCHA challenges that often occur when using browser-based search
|
||||
|
||||
## Configuration
|
||||
|
||||
### Enabling Search
|
||||
|
||||
To enable the search engine feature, set the following in your `config.toml`:
|
||||
|
||||
```toml
|
||||
[agent]
|
||||
enable_search_engine = true
|
||||
```
|
||||
|
||||
Or when using Docker, set the environment variable:
|
||||
```bash
|
||||
-e AGENT_ENABLE_SEARCH_ENGINE=true
|
||||
```
|
||||
|
||||
### API Key Setup
|
||||
|
||||
The search feature requires a Brave Search API key. You can obtain one from the [Brave Search API Dashboard](https://api.search.brave.com/app/keys).
|
||||
|
||||
Set the API key in your `config.toml`:
|
||||
```toml
|
||||
[search]
|
||||
enabled = true
|
||||
api_key = "your-api-key-here"
|
||||
```
|
||||
|
||||
Or when using Docker:
|
||||
```bash
|
||||
-e SEARCH_ENABLED=true
|
||||
-e SEARCH_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
## Search Results
|
||||
|
||||
When a search is performed, the results are returned in a structured format that includes:
|
||||
|
||||
- Web search results
|
||||
- News articles
|
||||
- Video content
|
||||
- FAQ entries
|
||||
- Discussion threads
|
||||
- Infoboxes (when available)
|
||||
- Location information (when relevant)
|
||||
|
||||
Each result type includes:
|
||||
- Title
|
||||
- URL (when applicable)
|
||||
- Description or snippet
|
||||
- Additional metadata specific to the result type
|
||||
|
||||
## Usage Example
|
||||
|
||||
When the search feature is enabled, agents can use the `search_engine` tool to perform searches. For example:
|
||||
|
||||
```python
|
||||
# The agent can make a tool call like this:
|
||||
{
|
||||
"name": "search_engine",
|
||||
"arguments": {
|
||||
"query": "latest developments in AI"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The search results will be returned in a markdown-formatted structure that's easy for the agent to parse and understand.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Query Formulation**
|
||||
- Keep queries focused and specific
|
||||
- Include relevant keywords
|
||||
- Avoid overly complex or compound queries
|
||||
|
||||
2. **Rate Limiting**
|
||||
- Be mindful of API rate limits
|
||||
- Cache results when appropriate
|
||||
- Implement retries with exponential backoff for failed requests
|
||||
|
||||
3. **Error Handling**
|
||||
- Handle API errors gracefully
|
||||
- Provide meaningful feedback when searches fail
|
||||
- Have fallback strategies when search is unavailable
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Common issues and solutions:
|
||||
|
||||
1. **Search Not Working**
|
||||
- Verify `enable_search_engine` is set to `true`
|
||||
- Confirm the Brave API key is correctly set
|
||||
- Check API key permissions and quotas
|
||||
|
||||
2. **No Results**
|
||||
- Verify the query is not empty
|
||||
- Try reformulating the search query
|
||||
- Check for any API response errors
|
||||
|
||||
3. **Rate Limiting**
|
||||
- Monitor API usage
|
||||
- Implement caching if needed
|
||||
- Consider upgrading API tier if limits are consistently hit
|
||||
@@ -8,7 +8,7 @@ function CustomFooter() {
|
||||
<footer className="custom-footer">
|
||||
<div className="footer-content">
|
||||
<div className="footer-icons">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw" target="_blank" rel="noopener noreferrer">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg" target="_blank" rel="noopener noreferrer">
|
||||
<FaSlack />
|
||||
</a>
|
||||
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">
|
||||
|
||||
@@ -46,7 +46,7 @@ export function HomepageHeader() {
|
||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
|
||||
<br/>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
@@ -176,11 +175,6 @@ def process_instance(
|
||||
logger.warning(
|
||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||
)
|
||||
metadata = copy.deepcopy(metadata)
|
||||
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||
metadata.details['remote_runtime_resource_factor'] = (
|
||||
config.sandbox.remote_runtime_resource_factor
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = create_runtime(config)
|
||||
@@ -302,20 +296,14 @@ def process_instance(
|
||||
with open(test_output_path, 'w') as f:
|
||||
f.write(test_output)
|
||||
try:
|
||||
extra_kwargs = {}
|
||||
if 'SWE-Gym' in metadata.dataset:
|
||||
# SWE-Gym uses a different version of the package, hence a different eval report argument
|
||||
extra_kwargs['log_path'] = test_output_path
|
||||
else:
|
||||
extra_kwargs['test_log_path'] = test_output_path
|
||||
_report = conditional_imports.get_eval_report(
|
||||
test_spec=test_spec,
|
||||
prediction={
|
||||
'model_patch': model_patch,
|
||||
'instance_id': instance_id,
|
||||
},
|
||||
test_log_path=test_output_path,
|
||||
include_tests_status=True,
|
||||
**extra_kwargs,
|
||||
)
|
||||
report = _report[instance_id]
|
||||
logger.info(
|
||||
@@ -475,7 +463,6 @@ if __name__ == '__main__':
|
||||
.decode('utf-8')
|
||||
.strip(), # Current commit
|
||||
dataset=args.dataset, # Dataset name from args
|
||||
details={},
|
||||
)
|
||||
|
||||
# The evaluation harness constrains the signature of `process_instance_func` but we need to
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
|
||||
if dataset_name not in _global_resource_mapping:
|
||||
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
|
||||
if not os.path.exists(file_path):
|
||||
logger.info(f'Resource mapping for {dataset_name} not found.')
|
||||
logger.warning(f'Resource mapping for {dataset_name} not found.')
|
||||
return None
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@@ -150,8 +149,7 @@ def get_config(
|
||||
) -> AppConfig:
|
||||
# We use a different instance image for the each instance of swe-bench eval
|
||||
use_official_image = bool(
|
||||
('verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower())
|
||||
and 'swe-gym' not in metadata.dataset.lower()
|
||||
'verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower()
|
||||
)
|
||||
base_container_image = get_instance_docker_image(
|
||||
instance['instance_id'], use_official_image
|
||||
@@ -477,13 +475,6 @@ def process_instance(
|
||||
logger.warning(
|
||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||
)
|
||||
|
||||
metadata = copy.deepcopy(metadata)
|
||||
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||
metadata.details['remote_runtime_resource_factor'] = (
|
||||
config.sandbox.remote_runtime_resource_factor
|
||||
)
|
||||
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
@@ -569,6 +560,20 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
|
||||
return dataset
|
||||
|
||||
|
||||
# A list of instances that are known to be tricky to infer
|
||||
# (will cause runtime failure even with resource factor = 8)
|
||||
SWEGYM_EXCLUDE_IDS = [
|
||||
'dask__dask-10422',
|
||||
'pandas-dev__pandas-50548',
|
||||
'pandas-dev__pandas-53672',
|
||||
'pandas-dev__pandas-54174',
|
||||
'pandas-dev__pandas-55518',
|
||||
'pandas-dev__pandas-58383',
|
||||
'pydata__xarray-6721',
|
||||
'pytest-dev__pytest-10081',
|
||||
'pytest-dev__pytest-7236',
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_parser()
|
||||
parser.add_argument(
|
||||
@@ -593,20 +598,11 @@ if __name__ == '__main__':
|
||||
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
||||
)
|
||||
if 'SWE-Gym' in args.dataset:
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'split',
|
||||
'swegym_verified_instances.json',
|
||||
),
|
||||
'r',
|
||||
) as f:
|
||||
swegym_verified_instances = json.load(f)
|
||||
swe_bench_tests = swe_bench_tests[
|
||||
swe_bench_tests['instance_id'].isin(swegym_verified_instances)
|
||||
]
|
||||
swe_bench_tests = swe_bench_tests[
|
||||
~swe_bench_tests['instance_id'].isin(SWEGYM_EXCLUDE_IDS)
|
||||
]
|
||||
logger.info(
|
||||
f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances'
|
||||
f'{len(swe_bench_tests)} tasks left after excluding SWE-Gym excluded tasks'
|
||||
)
|
||||
|
||||
llm_config = None
|
||||
|
||||
@@ -9,7 +9,7 @@ parser.add_argument(
|
||||
'--dataset_name',
|
||||
type=str,
|
||||
help='Name of the dataset to download',
|
||||
default='princeton-nlp/SWE-bench_Verified',
|
||||
default='princeton-nlp/SWE-bench_Lite',
|
||||
)
|
||||
parser.add_argument('--split', type=str, help='Split to download', default='test')
|
||||
args = parser.parse_args()
|
||||
@@ -20,12 +20,7 @@ print(
|
||||
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
|
||||
)
|
||||
patches = [
|
||||
{
|
||||
'instance_id': row['instance_id'],
|
||||
'model_patch': row['patch'],
|
||||
'model_name_or_path': 'gold',
|
||||
}
|
||||
for row in dataset
|
||||
{'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
|
||||
]
|
||||
print(f'{len(patches)} gold patches loaded')
|
||||
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@ from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
FAKE_RESPONSES = {
|
||||
'CodeActAgent': fake_user_response,
|
||||
'DelegatorAgent': fake_user_response,
|
||||
'VisualBrowsingAgent': fake_user_response,
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ load_dotenv()
|
||||
from openhands.agenthub import ( # noqa: E402
|
||||
browsing_agent,
|
||||
codeact_agent,
|
||||
delegator_agent,
|
||||
dummy_agent,
|
||||
visualbrowsing_agent,
|
||||
)
|
||||
@@ -14,6 +15,7 @@ from openhands.controller.agent import Agent # noqa: E402
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'codeact_agent',
|
||||
'delegator_agent',
|
||||
'dummy_agent',
|
||||
'browsing_agent',
|
||||
'visualbrowsing_agent',
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import openhands
|
||||
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
@@ -70,18 +72,23 @@ class CodeActAgent(Agent):
|
||||
codeact_enable_browsing=self.config.codeact_enable_browsing,
|
||||
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
|
||||
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
||||
codeact_enable_search_engine=self.config.enable_search_engine,
|
||||
llm=self.llm,
|
||||
)
|
||||
logger.debug(
|
||||
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
|
||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
|
||||
)
|
||||
self.prompt_manager = PromptManager(
|
||||
microagent_dir=os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'microagents',
|
||||
)
|
||||
if self.config.enable_prompt_extensions
|
||||
else None,
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
disabled_microagents=self.config.disabled_microagents,
|
||||
)
|
||||
|
||||
# Create a ConversationMemory instance
|
||||
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
|
||||
self.conversation_memory = ConversationMemory(self.prompt_manager)
|
||||
|
||||
self.condenser = Condenser.from_config(self.config.condenser)
|
||||
logger.debug(f'Using condenser: {type(self.condenser)}')
|
||||
@@ -161,7 +168,7 @@ class CodeActAgent(Agent):
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use ConversationMemory to process initial messages
|
||||
# Use conversation_memory to process events instead of calling events_to_messages directly
|
||||
messages = self.conversation_memory.process_initial_messages(
|
||||
with_caching=self.llm.is_caching_prompt_active()
|
||||
)
|
||||
@@ -173,12 +180,12 @@ class CodeActAgent(Agent):
|
||||
f'Processing {len(events)} events from a total of {len(state.history)} events'
|
||||
)
|
||||
|
||||
# Use ConversationMemory to process events
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_messages=messages,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
enable_som_visual_browsing=self.config.enable_som_visual_browsing,
|
||||
)
|
||||
|
||||
messages = self._enhance_messages(messages)
|
||||
@@ -209,7 +216,14 @@ class CodeActAgent(Agent):
|
||||
# compose the first user message with examples
|
||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||
|
||||
elif msg.role == 'user':
|
||||
# and/or repo/runtime info
|
||||
if self.config.enable_prompt_extensions:
|
||||
self.prompt_manager.add_info_to_initial_message(msg)
|
||||
|
||||
# enhance the user message with additional context based on keywords matched
|
||||
if msg.role == 'user':
|
||||
self.prompt_manager.enhance_message(msg)
|
||||
|
||||
# Add double newline between consecutive user messages
|
||||
if prev_role == 'user' and len(msg.content) > 0:
|
||||
# Find the first TextContent in the message to add newlines
|
||||
|
||||
@@ -12,14 +12,13 @@ from litellm import (
|
||||
|
||||
from openhands.agenthub.codeact_agent.tools import (
|
||||
BrowserTool,
|
||||
CmdRunTool,
|
||||
FinishTool,
|
||||
IPythonTool,
|
||||
LLMBasedFileEditTool,
|
||||
SearchEngineTool,
|
||||
StrReplaceEditorTool,
|
||||
ThinkTool,
|
||||
WebReadTool,
|
||||
create_cmd_run_tool,
|
||||
create_str_replace_editor_tool,
|
||||
)
|
||||
from openhands.core.exceptions import (
|
||||
FunctionCallNotExistsError,
|
||||
@@ -37,11 +36,9 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
SearchAction,
|
||||
)
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm import LLM
|
||||
|
||||
|
||||
def combine_thought(action: Action, thought: str) -> Action:
|
||||
@@ -83,7 +80,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
# CmdRunTool (Bash)
|
||||
# ================================================
|
||||
|
||||
if tool_call.function.name == create_cmd_run_tool()['function']['name']:
|
||||
if tool_call.function.name == CmdRunTool['function']['name']:
|
||||
if 'command' not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||
@@ -134,10 +131,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
start=arguments.get('start', 1),
|
||||
end=arguments.get('end', -1),
|
||||
)
|
||||
elif (
|
||||
tool_call.function.name
|
||||
== create_str_replace_editor_tool()['function']['name']
|
||||
):
|
||||
elif tool_call.function.name == StrReplaceEditorTool['function']['name']:
|
||||
if 'command' not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||
@@ -193,15 +187,6 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
f'Missing required argument "url" in tool call {tool_call.function.name}'
|
||||
)
|
||||
action = BrowseURLAction(url=arguments['url'])
|
||||
# ================================================
|
||||
# SearchEngineTool (search the web using text queries)
|
||||
# ================================================
|
||||
elif tool_call.function.name == SearchEngineTool['function']['name']:
|
||||
if 'query' not in arguments:
|
||||
raise FunctionCallNotExistsError(
|
||||
f'Missing required argument "query" in tool call {tool_call.function.name}'
|
||||
)
|
||||
action = SearchAction(query=arguments['query'])
|
||||
else:
|
||||
raise FunctionCallNotExistsError(
|
||||
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'
|
||||
@@ -234,25 +219,8 @@ def get_tools(
|
||||
codeact_enable_browsing: bool = False,
|
||||
codeact_enable_llm_editor: bool = False,
|
||||
codeact_enable_jupyter: bool = False,
|
||||
codeact_enable_search_engine: bool = False,
|
||||
llm: LLM | None = None,
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1']
|
||||
|
||||
use_simplified_tool_desc = False
|
||||
if llm is not None:
|
||||
use_simplified_tool_desc = any(
|
||||
model_substr in llm.config.model
|
||||
for model_substr in SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS
|
||||
)
|
||||
|
||||
tools = [
|
||||
create_cmd_run_tool(use_simplified_description=use_simplified_tool_desc),
|
||||
ThinkTool,
|
||||
FinishTool,
|
||||
]
|
||||
if codeact_enable_search_engine:
|
||||
tools.append(SearchEngineTool)
|
||||
tools = [CmdRunTool, ThinkTool, FinishTool]
|
||||
if codeact_enable_browsing:
|
||||
tools.append(WebReadTool)
|
||||
tools.append(BrowserTool)
|
||||
@@ -261,9 +229,5 @@ def get_tools(
|
||||
if codeact_enable_llm_editor:
|
||||
tools.append(LLMBasedFileEditTool)
|
||||
else:
|
||||
tools.append(
|
||||
create_str_replace_editor_tool(
|
||||
use_simplified_description=use_simplified_tool_desc
|
||||
)
|
||||
)
|
||||
tools.append(StrReplaceEditorTool)
|
||||
return tools
|
||||
|
||||
@@ -20,8 +20,6 @@ When starting a web server, use the corresponding ports. You should also
|
||||
set any options to allow iframes and CORS requests, and allow the server to
|
||||
be accessed from any host (e.g. 0.0.0.0).
|
||||
{% endif %}
|
||||
{% if runtime_info.additional_agent_instructions %}
|
||||
{{ runtime_info.additional_agent_instructions }}
|
||||
{% endif %}
|
||||
</RUNTIME_INFORMATION>
|
||||
{% endif %}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
{{ agent_info.content }}
|
||||
{{ agent_info.agent.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
from .bash import create_cmd_run_tool
|
||||
from .bash import CmdRunTool
|
||||
from .browser import BrowserTool
|
||||
from .finish import FinishTool
|
||||
from .ipython import IPythonTool
|
||||
from .llm_based_edit import LLMBasedFileEditTool
|
||||
from .search_engine import SearchEngineTool
|
||||
from .str_replace_editor import create_str_replace_editor_tool
|
||||
from .str_replace_editor import StrReplaceEditorTool
|
||||
from .think import ThinkTool
|
||||
from .web_read import WebReadTool
|
||||
|
||||
__all__ = [
|
||||
'BrowserTool',
|
||||
'create_cmd_run_tool',
|
||||
'CmdRunTool',
|
||||
'FinishTool',
|
||||
'IPythonTool',
|
||||
'LLMBasedFileEditTool',
|
||||
'SearchEngineTool',
|
||||
'create_str_replace_editor_tool',
|
||||
'StrReplaceEditorTool',
|
||||
'WebReadTool',
|
||||
'ThinkTool',
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||
|
||||
_DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
||||
_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
||||
|
||||
### Command Execution
|
||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
|
||||
@@ -22,39 +22,25 @@ _DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a
|
||||
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
|
||||
"""
|
||||
|
||||
_SIMPLIFIED_BASH_DESCRIPTION = """Execute a bash command in the terminal.
|
||||
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
|
||||
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
|
||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."""
|
||||
|
||||
|
||||
def create_cmd_run_tool(
|
||||
use_simplified_description: bool = False,
|
||||
) -> ChatCompletionToolParam:
|
||||
description = (
|
||||
_SIMPLIFIED_BASH_DESCRIPTION
|
||||
if use_simplified_description
|
||||
else _DETAILED_BASH_DESCRIPTION
|
||||
)
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='execute_bash',
|
||||
description=description,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||
},
|
||||
'is_input': {
|
||||
'type': 'string',
|
||||
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
||||
'enum': ['true', 'false'],
|
||||
},
|
||||
CmdRunTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='execute_bash',
|
||||
description=_BASH_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||
},
|
||||
'is_input': {
|
||||
'type': 'string',
|
||||
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
||||
'enum': ['true', 'false'],
|
||||
},
|
||||
'required': ['command'],
|
||||
},
|
||||
),
|
||||
)
|
||||
'required': ['command'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||
|
||||
_SEARCH_ENGINE_DESCRIPTION = """Execute a web search query (similar to Google search).
|
||||
|
||||
NOTE: When you need to search for information online, please use the `search_engine` tool rather than the `browser` or `web_read` tools. The `search_engine` tool connects directly to a search engine, which will help avoid CAPTCHA challenges that would otherwise block your access.
|
||||
"""
|
||||
|
||||
SearchEngineTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='search_engine',
|
||||
description=_SEARCH_ENGINE_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'query': {
|
||||
'type': 'string',
|
||||
'description': 'The web search query (must be a non-empty string).',
|
||||
},
|
||||
},
|
||||
'required': ['query'],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||
|
||||
_DETAILED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
* State is persistent across command calls and discussions with the user
|
||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||
@@ -31,73 +31,46 @@ CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
||||
"""
|
||||
|
||||
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
* State is persistent across command calls and discussions with the user
|
||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
|
||||
* The `undo_edit` command will revert the last edit made to the file at `path`
|
||||
Notes for using the `str_replace` command:
|
||||
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
|
||||
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
|
||||
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
|
||||
"""
|
||||
|
||||
|
||||
def create_str_replace_editor_tool(
|
||||
use_simplified_description: bool = False,
|
||||
) -> ChatCompletionToolParam:
|
||||
description = (
|
||||
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||
if use_simplified_description
|
||||
else _DETAILED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||
)
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='str_replace_editor',
|
||||
description=description,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
||||
'enum': [
|
||||
'view',
|
||||
'create',
|
||||
'str_replace',
|
||||
'insert',
|
||||
'undo_edit',
|
||||
],
|
||||
'type': 'string',
|
||||
},
|
||||
'path': {
|
||||
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
||||
'type': 'string',
|
||||
},
|
||||
'file_text': {
|
||||
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
||||
'type': 'string',
|
||||
},
|
||||
'old_str': {
|
||||
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
||||
'type': 'string',
|
||||
},
|
||||
'new_str': {
|
||||
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
||||
'type': 'string',
|
||||
},
|
||||
'insert_line': {
|
||||
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
||||
'type': 'integer',
|
||||
},
|
||||
'view_range': {
|
||||
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
||||
'items': {'type': 'integer'},
|
||||
'type': 'array',
|
||||
},
|
||||
StrReplaceEditorTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='str_replace_editor',
|
||||
description=_STR_REPLACE_EDITOR_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
||||
'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
|
||||
'type': 'string',
|
||||
},
|
||||
'path': {
|
||||
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
||||
'type': 'string',
|
||||
},
|
||||
'file_text': {
|
||||
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
||||
'type': 'string',
|
||||
},
|
||||
'old_str': {
|
||||
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
||||
'type': 'string',
|
||||
},
|
||||
'new_str': {
|
||||
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
||||
'type': 'string',
|
||||
},
|
||||
'insert_line': {
|
||||
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
||||
'type': 'integer',
|
||||
},
|
||||
'view_range': {
|
||||
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
||||
'items': {'type': 'integer'},
|
||||
'type': 'array',
|
||||
},
|
||||
'required': ['command', 'path'],
|
||||
},
|
||||
),
|
||||
)
|
||||
'required': ['command', 'path'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
4
openhands/agenthub/delegator_agent/__init__.py
Normal file
4
openhands/agenthub/delegator_agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from openhands.agenthub.delegator_agent.agent import DelegatorAgent
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
Agent.register('DelegatorAgent', DelegatorAgent)
|
||||
87
openhands/agenthub/delegator_agent/agent.py
Normal file
87
openhands/agenthub/delegator_agent/agent.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
|
||||
from openhands.events.observation import AgentDelegateObservation, Observation
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class DelegatorAgent(Agent):
|
||||
VERSION = '1.0'
|
||||
"""
|
||||
The Delegator Agent is responsible for delegating tasks to other agents based on the current task.
|
||||
"""
|
||||
|
||||
current_delegate: str = ''
|
||||
|
||||
def __init__(self, llm: LLM, config: AgentConfig):
|
||||
"""Initialize the Delegator Agent with an LLM
|
||||
|
||||
Parameters:
|
||||
- llm (LLM): The llm to be used by this agent
|
||||
"""
|
||||
super().__init__(llm, config)
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
"""Checks to see if current step is completed, returns AgentFinishAction if True.
|
||||
Otherwise, delegates the task to the next agent in the pipeline.
|
||||
|
||||
Parameters:
|
||||
- state (State): The current state given the previous actions and observations
|
||||
|
||||
Returns:
|
||||
- AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned'
|
||||
- AgentDelegateAction: The next agent to delegate the task to
|
||||
"""
|
||||
if self.current_delegate == '':
|
||||
self.current_delegate = 'study'
|
||||
task, _ = state.get_current_user_intent()
|
||||
return AgentDelegateAction(
|
||||
agent='StudyRepoForTaskAgent', inputs={'task': task}
|
||||
)
|
||||
|
||||
# last observation in history should be from the delegate
|
||||
last_observation = None
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, Observation):
|
||||
last_observation = event
|
||||
break
|
||||
|
||||
if not isinstance(last_observation, AgentDelegateObservation):
|
||||
raise Exception('Last observation is not an AgentDelegateObservation')
|
||||
|
||||
goal, _ = state.get_current_user_intent()
|
||||
if self.current_delegate == 'study':
|
||||
self.current_delegate = 'coder'
|
||||
return AgentDelegateAction(
|
||||
agent='CoderAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
)
|
||||
elif self.current_delegate == 'coder':
|
||||
self.current_delegate = 'verifier'
|
||||
return AgentDelegateAction(
|
||||
agent='VerifierAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
},
|
||||
)
|
||||
elif self.current_delegate == 'verifier':
|
||||
if (
|
||||
'completed' in last_observation.outputs
|
||||
and last_observation.outputs['completed']
|
||||
):
|
||||
return AgentFinishAction()
|
||||
else:
|
||||
self.current_delegate = 'coder'
|
||||
return AgentDelegateAction(
|
||||
agent='CoderAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise Exception('Invalid delegate state')
|
||||
@@ -202,7 +202,6 @@ Note:
|
||||
tabs = ''
|
||||
last_obs = None
|
||||
last_action = None
|
||||
set_of_marks = None # Initialize set_of_marks to None
|
||||
|
||||
if len(state.history) == 1:
|
||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
@@ -218,9 +217,6 @@ Note:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
elif isinstance(event, Observation):
|
||||
# Only process BrowserOutputObservation and skip other observation types
|
||||
if not isinstance(event, BrowserOutputObservation):
|
||||
continue
|
||||
last_obs = event
|
||||
|
||||
if len(prev_actions) >= 1: # ignore noop()
|
||||
|
||||
@@ -29,12 +29,7 @@ from openhands.core.exceptions import (
|
||||
from openhands.core.logger import LOG_ALL_EVENTS
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import (
|
||||
EventSource,
|
||||
EventStream,
|
||||
EventStreamSubscriber,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
@@ -47,7 +42,6 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@@ -95,7 +89,7 @@ class AgentController:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
sid: str | None = None,
|
||||
sid: str = 'default',
|
||||
confirmation_mode: bool = False,
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
@@ -122,7 +116,7 @@ class AgentController:
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
replay_events: A list of logs to replay.
|
||||
"""
|
||||
self.id = sid or event_stream.sid
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
self.is_delegate = is_delegate
|
||||
@@ -293,14 +287,8 @@ class AgentController:
|
||||
return True
|
||||
return False
|
||||
if isinstance(event, Observation):
|
||||
if (
|
||||
isinstance(event, NullObservation)
|
||||
and event.cause is not None
|
||||
and event.cause > 0
|
||||
):
|
||||
return True
|
||||
if isinstance(event, AgentStateChangedObservation) or isinstance(
|
||||
event, NullObservation
|
||||
if isinstance(event, NullObservation) or isinstance(
|
||||
event, AgentStateChangedObservation
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -400,7 +388,6 @@ class AgentController:
|
||||
if observation.llm_metrics is not None:
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
# this happens for runnable actions and microagent actions
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
return
|
||||
@@ -444,25 +431,6 @@ class AgentController:
|
||||
'debug',
|
||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||
)
|
||||
# try to retrieve microagents relevant to the user message
|
||||
# set pending_action while we search for information
|
||||
|
||||
# if this is the first user message for this agent, matters for the microagent info type
|
||||
first_user_message = self._first_user_message()
|
||||
is_first_user_message = (
|
||||
action.id == first_user_message.id if first_user_message else False
|
||||
)
|
||||
recall_type = (
|
||||
RecallType.WORKSPACE_CONTEXT
|
||||
if is_first_user_message
|
||||
else RecallType.KNOWLEDGE
|
||||
)
|
||||
|
||||
recall_action = RecallAction(query=action.content, recall_type=recall_type)
|
||||
self._pending_action = recall_action
|
||||
# this is source=USER because the user message is the trigger for the microagent retrieval
|
||||
self.event_stream.add_event(recall_action, EventSource.USER)
|
||||
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||
@@ -470,7 +438,6 @@ class AgentController:
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Resets the agent controller"""
|
||||
# Runnable actions need an Observation
|
||||
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
||||
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
||||
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
||||
@@ -492,8 +459,6 @@ class AgentController:
|
||||
obs._cause = self._pending_action.id # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
|
||||
|
||||
# reset the pending action, this will be called when the agent is STOPPED or ERROR
|
||||
self._pending_action = None
|
||||
self.agent.reset()
|
||||
@@ -1181,26 +1146,3 @@ class AgentController:
|
||||
result = event.agent_state == AgentState.RUNNING
|
||||
return result
|
||||
return False
|
||||
|
||||
def _first_user_message(self) -> MessageAction | None:
|
||||
"""
|
||||
Get the first user message for this agent.
|
||||
|
||||
For regular agents, this is the first user message from the beginning (start_id=0).
|
||||
For delegate agents, this is the first user message after the delegate's start_id.
|
||||
|
||||
Returns:
|
||||
MessageAction | None: The first user message, or None if no user message found
|
||||
"""
|
||||
# Find the first user message from the appropriate starting point
|
||||
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
|
||||
|
||||
# Get and return the first user message
|
||||
return next(
|
||||
(
|
||||
e
|
||||
for e in user_messages
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -135,7 +135,7 @@ class StuckDetector:
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
# check if the last three actions are the same and result in errors
|
||||
|
||||
if len(last_actions) < 3 or len(last_observations) < 3:
|
||||
if len(last_actions) < 4 or len(last_observations) < 4:
|
||||
return False
|
||||
|
||||
# are the last three actions the "same"?
|
||||
|
||||
@@ -17,7 +17,6 @@ from openhands.core.schema import AgentState
|
||||
from openhands.core.setup import (
|
||||
create_agent,
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
initialize_repository_for_runtime,
|
||||
)
|
||||
@@ -171,22 +170,13 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
await runtime.connect()
|
||||
|
||||
# Initialize repository if needed
|
||||
repo_directory = None
|
||||
if config.sandbox.selected_repo:
|
||||
repo_directory = initialize_repository_for_runtime(
|
||||
initialize_repository_for_runtime(
|
||||
runtime,
|
||||
agent=agent,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
memory = create_memory(
|
||||
runtime=runtime,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
if initial_user_action:
|
||||
# If there's an initial user action, enqueue it and do not prompt again
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
@@ -195,7 +185,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from openhands.core.config.config_utils import (
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.search_config import SearchConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.core.config.utils import (
|
||||
finalize_config,
|
||||
@@ -29,7 +28,6 @@ __all__ = [
|
||||
'AppConfig',
|
||||
'LLMConfig',
|
||||
'SandboxConfig',
|
||||
'SearchConfig',
|
||||
'SecurityConfig',
|
||||
'ExtendedConfig',
|
||||
'load_app_config',
|
||||
|
||||
@@ -2,10 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from openhands.core.config.condenser_config import (
|
||||
CondenserConfig,
|
||||
NoOpCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.condenser_config import CondenserConfig, NoOpCondenserConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -33,7 +30,6 @@ class AgentConfig(BaseModel):
|
||||
disabled_microagents: list[str] = Field(default_factory=list)
|
||||
enable_history_truncation: bool = Field(default=True)
|
||||
enable_som_visual_browsing: bool = Field(default=False)
|
||||
enable_search_engine: bool = Field(default=False)
|
||||
condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
@@ -12,7 +12,6 @@ from openhands.core.config.config_utils import (
|
||||
from openhands.core.config.extended_config import ExtendedConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.search_config import SearchConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
|
||||
|
||||
@@ -54,7 +53,6 @@ class AppConfig(BaseModel):
|
||||
default_agent: str = Field(default=OH_DEFAULT_AGENT)
|
||||
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
|
||||
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||
search: SearchConfig = Field(default_factory=SearchConfig)
|
||||
extended: ExtendedConfig = Field(default_factory=lambda: ExtendedConfig({}))
|
||||
runtime: str = Field(default='docker')
|
||||
file_store: str = Field(default='local')
|
||||
|
||||
@@ -15,7 +15,6 @@ class SandboxConfig(BaseModel):
|
||||
timeout: The timeout for the default sandbox action execution.
|
||||
remote_runtime_init_timeout: The timeout for the remote runtime to start.
|
||||
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
|
||||
remote_runtime_enable_retries: Whether to enable retries (on recoverable errors like requests.ConnectionError) for the remote runtime API requests.
|
||||
enable_auto_lint: Whether to enable auto-lint.
|
||||
use_host_network: Whether to use the host network.
|
||||
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
|
||||
@@ -54,7 +53,7 @@ class SandboxConfig(BaseModel):
|
||||
timeout: int = Field(default=120)
|
||||
remote_runtime_init_timeout: int = Field(default=180)
|
||||
remote_runtime_api_timeout: int = Field(default=10)
|
||||
remote_runtime_enable_retries: bool = Field(default=True)
|
||||
remote_runtime_enable_retries: bool = Field(default=False)
|
||||
remote_runtime_class: str | None = Field(
|
||||
default=None
|
||||
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Configuration for search engine functionality."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
"""Configuration for search engine functionality.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether search engine functionality is enabled.
|
||||
api_key: The API key for the search engine.
|
||||
api_url: The base URL for the search API.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
api_key: SecretStr | None = Field(default=None)
|
||||
api_url: str = Field(default="https://api.search.brave.com/res/v1/web/search")
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Post-initialization hook to assign search-related variables to environment variables.
|
||||
|
||||
This ensures that these values are accessible to the search engine at runtime.
|
||||
"""
|
||||
super().model_post_init(__context)
|
||||
|
||||
# Set environment variables for search engine
|
||||
if self.api_key:
|
||||
os.environ["BRAVE_API_KEY"] = self.api_key.get_secret_value()
|
||||
if self.api_url:
|
||||
os.environ["BRAVE_API_URL"] = self.api_url
|
||||
@@ -240,7 +240,7 @@ class SensitiveDataFilter(logging.Filter):
|
||||
if (
|
||||
len(value) > 2
|
||||
and value != 'default'
|
||||
and any(s in key_upper for s in ('SECRET', '_KEY', '_CODE', '_TOKEN'))
|
||||
and any(s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN'))
|
||||
):
|
||||
sensitive_values.append(value)
|
||||
|
||||
|
||||
@@ -3,14 +3,12 @@ import asyncio
|
||||
from openhands.controller import AgentController
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
async def run_agent_until_done(
|
||||
controller: AgentController,
|
||||
runtime: Runtime,
|
||||
memory: Memory,
|
||||
end_states: list[AgentState],
|
||||
):
|
||||
"""
|
||||
@@ -39,7 +37,6 @@ async def run_agent_until_done(
|
||||
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
memory.status_callback = status_callback
|
||||
|
||||
while controller.state.agent_state not in end_states:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -18,7 +18,6 @@ from openhands.core.schema import AgentState
|
||||
from openhands.core.setup import (
|
||||
create_agent,
|
||||
create_controller,
|
||||
create_memory,
|
||||
create_runtime,
|
||||
generate_sid,
|
||||
initialize_repository_for_runtime,
|
||||
@@ -30,7 +29,6 @@ from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.io import read_input, read_task
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
@@ -53,7 +51,6 @@ async def run_controller(
|
||||
exit_on_message: bool = False,
|
||||
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
||||
headless_mode: bool = True,
|
||||
memory: Memory | None = None,
|
||||
) -> State | None:
|
||||
"""Main coroutine to run the agent controller with task input flexibility.
|
||||
|
||||
@@ -96,8 +93,6 @@ async def run_controller(
|
||||
if agent is None:
|
||||
agent = create_agent(config)
|
||||
|
||||
# when the runtime is created, it will be connected and clone the selected repository
|
||||
repo_directory = None
|
||||
if runtime is None:
|
||||
runtime = create_runtime(
|
||||
config,
|
||||
@@ -110,23 +105,14 @@ async def run_controller(
|
||||
|
||||
# Initialize repository if needed
|
||||
if config.sandbox.selected_repo:
|
||||
repo_directory = initialize_repository_for_runtime(
|
||||
initialize_repository_for_runtime(
|
||||
runtime,
|
||||
agent=agent,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
# when memory is created, it will load the microagents from the selected repository
|
||||
if memory is None:
|
||||
memory = create_memory(
|
||||
runtime=runtime,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
selected_repository=config.sandbox.selected_repo,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
replay_events: list[Event] | None = None
|
||||
if config.replay_trajectory_path:
|
||||
logger.info('Trajectory replay is enabled')
|
||||
@@ -186,7 +172,7 @@ async def run_controller(
|
||||
]
|
||||
|
||||
try:
|
||||
await run_agent_until_done(controller, runtime, memory, end_states)
|
||||
await run_agent_until_done(controller, runtime, end_states)
|
||||
except Exception as e:
|
||||
logger.error(f'Exception in main loop: {e}')
|
||||
|
||||
|
||||
@@ -82,11 +82,5 @@ class ActionTypeSchema(BaseModel):
|
||||
SEND_PR: str = Field(default='send_pr')
|
||||
"""Send a PR to github."""
|
||||
|
||||
SEARCH: str = Field(default='search')
|
||||
"""Queries a search engine."""
|
||||
|
||||
RECALL: str = Field(default='recall')
|
||||
"""Retrieves content from a user workspace, microagent, or other source."""
|
||||
|
||||
|
||||
ActionType = ActionTypeSchema()
|
||||
|
||||
@@ -49,11 +49,5 @@ class ObservationTypeSchema(BaseModel):
|
||||
CONDENSE: str = Field(default='condense')
|
||||
"""Result of a condensation operation."""
|
||||
|
||||
SEARCH: str = Field(default='search')
|
||||
"""Result of querying a search engine."""
|
||||
|
||||
RECALL: str = Field(default='recall')
|
||||
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
|
||||
|
||||
|
||||
ObservationType = ObservationTypeSchema()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable, Tuple, Type
|
||||
from typing import Tuple, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -16,7 +16,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
@@ -84,6 +83,7 @@ def create_runtime(
|
||||
|
||||
def initialize_repository_for_runtime(
|
||||
runtime: Runtime,
|
||||
agent: Agent | None = None,
|
||||
selected_repository: str | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
) -> str | None:
|
||||
@@ -91,6 +91,7 @@ def initialize_repository_for_runtime(
|
||||
|
||||
Args:
|
||||
runtime: The runtime to initialize the repository for.
|
||||
agent: (optional) The agent to load microagents for.
|
||||
selected_repository: (optional) The GitHub repository to use.
|
||||
github_token: (optional) The GitHub token to use.
|
||||
|
||||
@@ -98,10 +99,10 @@ def initialize_repository_for_runtime(
|
||||
The repository directory path if a repository was cloned, None otherwise.
|
||||
"""
|
||||
# clone selected repository if provided
|
||||
repo_directory = None
|
||||
github_token = (
|
||||
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
|
||||
)
|
||||
repo_directory = None
|
||||
if selected_repository and github_token:
|
||||
logger.debug(f'Selected repository {selected_repository}.')
|
||||
repo_directory = runtime.clone_repo(
|
||||
@@ -110,47 +111,16 @@ def initialize_repository_for_runtime(
|
||||
None,
|
||||
)
|
||||
|
||||
return repo_directory
|
||||
|
||||
|
||||
def create_memory(
|
||||
runtime: Runtime,
|
||||
event_stream: EventStream,
|
||||
sid: str,
|
||||
selected_repository: str | None = None,
|
||||
repo_directory: str | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
) -> Memory:
|
||||
"""Create a memory for the agent to use.
|
||||
|
||||
Args:
|
||||
runtime: The runtime to use.
|
||||
event_stream: The event stream it will subscribe to.
|
||||
sid: The session id.
|
||||
selected_repository: The repository to clone and start with, if any.
|
||||
repo_directory: The repository directory, if any.
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
"""
|
||||
memory = Memory(
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
|
||||
if runtime:
|
||||
# sets available hosts
|
||||
memory.set_runtime_info(runtime)
|
||||
|
||||
# loads microagents from repo/.openhands/microagents
|
||||
# load microagents from selected repository
|
||||
if agent and agent.prompt_manager and selected_repository and repo_directory:
|
||||
agent.prompt_manager.set_runtime_info(runtime)
|
||||
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
|
||||
selected_repository
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
agent.prompt_manager.load_microagents(microagents)
|
||||
agent.prompt_manager.set_repository_info(selected_repository, repo_directory)
|
||||
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(selected_repository, repo_directory)
|
||||
|
||||
return memory
|
||||
return repo_directory
|
||||
|
||||
|
||||
def create_agent(config: AppConfig) -> Agent:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
@@ -6,5 +6,4 @@ __all__ = [
|
||||
'EventSource',
|
||||
'EventStream',
|
||||
'EventStreamSubscriber',
|
||||
'RecallType',
|
||||
]
|
||||
|
||||
@@ -6,7 +6,6 @@ from openhands.events.action.agent import (
|
||||
AgentSummarizeAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||
@@ -17,7 +16,6 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.search_engine import SearchAction
|
||||
|
||||
__all__ = [
|
||||
'Action',
|
||||
@@ -37,6 +35,4 @@ __all__ = [
|
||||
'MessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'SearchAction',
|
||||
'RecallAction',
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -107,22 +106,3 @@ class AgentDelegateAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"I'm asking {self.agent} for help with this task."
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallAction(Action):
|
||||
"""This action is used for retrieving content, e.g., from the global directory or user workspace."""
|
||||
|
||||
recall_type: RecallType
|
||||
query: str = ''
|
||||
thought: str = ''
|
||||
action: str = ActionType.RECALL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Retrieving content for: {self.query[:50]}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**RecallAction**\n'
|
||||
ret += f'QUERY: {self.query[:50]}'
|
||||
return ret
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchAction(Action):
|
||||
query: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.SEARCH
|
||||
runnable: ClassVar[bool] = True
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I am querying the search engine to search for {self.query}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**SearchAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'QUERY: {self.query}'
|
||||
return ret
|
||||
@@ -22,16 +22,6 @@ class FileReadSource(str, Enum):
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.search_engine import SearchEngineObservation
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
@@ -43,7 +40,4 @@ __all__ = [
|
||||
'SuccessObservation',
|
||||
'UserRejectObservation',
|
||||
'AgentCondensationObservation',
|
||||
'SearchEngineObservation',
|
||||
'RecallObservation',
|
||||
'RecallType',
|
||||
]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@@ -41,90 +40,3 @@ class AgentThinkObservation(Observation):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroagentKnowledge:
|
||||
"""
|
||||
Represents knowledge from a triggered microagent.
|
||||
|
||||
Attributes:
|
||||
name: The name of the microagent that was triggered
|
||||
trigger: The word that triggered this microagent
|
||||
content: The actual content/knowledge from the microagent
|
||||
"""
|
||||
|
||||
name: str
|
||||
trigger: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallObservation(Observation):
|
||||
"""The retrieval of content from a microagent or more microagents."""
|
||||
|
||||
recall_type: RecallType
|
||||
observation: str = ObservationType.RECALL
|
||||
|
||||
# workspace context
|
||||
repo_name: str = ''
|
||||
repo_directory: str = ''
|
||||
repo_instructions: str = ''
|
||||
runtime_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
|
||||
# knowledge
|
||||
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
|
||||
"""
|
||||
A list of MicroagentKnowledge objects, each containing information from a triggered microagent.
|
||||
|
||||
Example:
|
||||
[
|
||||
MicroagentKnowledge(
|
||||
name="python_best_practices",
|
||||
trigger="python",
|
||||
content="Always use virtual environments for Python projects."
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name="git_workflow",
|
||||
trigger="git",
|
||||
content="Create a new branch for each feature or bugfix."
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return (
|
||||
'Added workspace context'
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
else 'Added microagent knowledge'
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Build a string representation
|
||||
fields = []
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
f'repo_name={self.repo_name}',
|
||||
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||
f'runtime_hosts={self.runtime_hosts}',
|
||||
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||
]
|
||||
)
|
||||
else:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
]
|
||||
)
|
||||
if self.microagent_knowledge:
|
||||
fields.extend(
|
||||
[
|
||||
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
|
||||
]
|
||||
)
|
||||
|
||||
return f'**RecallObservation**\n{", ".join(fields)}'
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchEngineObservation(Observation):
|
||||
query: str
|
||||
observation: str = ObservationType.SEARCH
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Searched for: {self.query}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = (
|
||||
'**SearchEngineObservation**\n'
|
||||
f'Query: {self.query}\n'
|
||||
f'Search Results: {self.content}\n'
|
||||
)
|
||||
return ret
|
||||
@@ -8,7 +8,6 @@ from openhands.events.action.agent import (
|
||||
AgentRejectAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import (
|
||||
@@ -22,7 +21,6 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.search_engine import SearchAction
|
||||
|
||||
actions = (
|
||||
NullAction,
|
||||
@@ -37,10 +35,8 @@ actions = (
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentDelegateAction,
|
||||
RecallAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
SearchAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -103,8 +102,6 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
d['timestamp'] = d['timestamp'].isoformat()
|
||||
if key == 'source' and 'source' in d:
|
||||
d['source'] = d['source'].value
|
||||
if key == 'recall_type' and 'recall_type' in d:
|
||||
d['recall_type'] = d['recall_type'].value
|
||||
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
|
||||
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
|
||||
if key == 'llm_metrics' and 'llm_metrics' in d:
|
||||
@@ -122,11 +119,7 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
|
||||
# such as CmdOutputMetadata
|
||||
# we serialize it along with the rest
|
||||
# we also handle the Enum conversion for RecallObservation
|
||||
d['extras'] = {
|
||||
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
|
||||
for k, v in props.items()
|
||||
}
|
||||
d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()}
|
||||
# Include success field for CmdOutputObservation
|
||||
if hasattr(event, 'success'):
|
||||
d['success'] = event.success
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import copy
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@@ -43,7 +40,6 @@ observations = (
|
||||
UserRejectObservation,
|
||||
AgentCondensationObservation,
|
||||
AgentThinkObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
@@ -114,18 +110,4 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
else:
|
||||
extras['metadata'] = CmdOutputMetadata()
|
||||
|
||||
if observation_class is RecallObservation:
|
||||
# handle the Enum conversion
|
||||
if 'recall_type' in extras:
|
||||
extras['recall_type'] = RecallType(extras['recall_type'])
|
||||
|
||||
# convert dicts in microagent_knowledge to MicroagentKnowledge objects
|
||||
if 'microagent_knowledge' in extras and isinstance(
|
||||
extras['microagent_knowledge'], list
|
||||
):
|
||||
extras['microagent_knowledge'] = [
|
||||
MicroagentKnowledge(**item) if isinstance(item, dict) else item
|
||||
for item in extras['microagent_knowledge']
|
||||
]
|
||||
|
||||
return observation_class(content=content, **extras)
|
||||
|
||||
@@ -27,7 +27,6 @@ class EventStreamSubscriber(str, Enum):
|
||||
RESOLVER = 'openhands_resolver'
|
||||
SERVER = 'server'
|
||||
RUNTIME = 'runtime'
|
||||
MEMORY = 'memory'
|
||||
MAIN = 'main'
|
||||
TEST = 'test'
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
@@ -16,7 +15,7 @@ from openhands.integrations.service_types import (
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
class GitHubService(GitService):
|
||||
BASE_URL = 'https://api.github.com'
|
||||
@@ -26,7 +25,6 @@ class GitHubService(GitService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
|
||||
@@ -249,8 +249,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# if we mocked function calling, and we have tools, convert the response back to function calling format
|
||||
if mock_function_calling and mock_fncall_tools is not None:
|
||||
logger.debug(f'Response choices: {len(resp.choices)}')
|
||||
assert len(resp.choices) >= 1
|
||||
assert len(resp.choices) == 1
|
||||
non_fncall_response_message = resp.choices[0].message
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from litellm import ModelResponse
|
||||
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.core.schema import ActionType
|
||||
@@ -17,7 +16,7 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
@@ -27,24 +26,18 @@ from openhands.events.observation import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
IPythonRunCellObservation,
|
||||
SearchEngineObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
"""Processes event history into a coherent conversation for the agent."""
|
||||
|
||||
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
|
||||
self.agent_config = config
|
||||
def __init__(self, prompt_manager: PromptManager):
|
||||
self.prompt_manager = prompt_manager
|
||||
|
||||
def process_events(
|
||||
@@ -53,24 +46,23 @@ class ConversationMemory:
|
||||
initial_messages: list[Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Process state history into a list of messages for the LLM.
|
||||
|
||||
Ensures that tool call actions are processed correctly in function calling mode.
|
||||
|
||||
Args:
|
||||
condensed_history: The condensed history of events to convert
|
||||
initial_messages: The initial messages to include in the conversation
|
||||
state: The state containing the history of events to convert
|
||||
condensed_history: The condensed list of events to process
|
||||
initial_messages: The initial messages to include in the result
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
|
||||
"""
|
||||
|
||||
events = condensed_history
|
||||
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
messages = initial_messages
|
||||
|
||||
@@ -78,7 +70,7 @@ class ConversationMemory:
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
|
||||
for i, event in enumerate(events):
|
||||
for event in events:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
messages_to_add = self._process_action(
|
||||
@@ -92,9 +84,7 @@ class ConversationMemory:
|
||||
tool_call_id_to_message=tool_call_id_to_message,
|
||||
max_message_chars=max_message_chars,
|
||||
vision_is_active=vision_is_active,
|
||||
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
|
||||
current_index=i,
|
||||
events=events,
|
||||
enable_som_visual_browsing=enable_som_visual_browsing,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {type(event)}')
|
||||
@@ -280,8 +270,6 @@ class ConversationMemory:
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
current_index: int = 0,
|
||||
events: list[Event] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Converts an observation into a message format that can be sent to the LLM.
|
||||
|
||||
@@ -303,8 +291,6 @@ class ConversationMemory:
|
||||
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
||||
current_index: The index of the current event in the events list (for deduplication)
|
||||
events: The list of all events (for deduplication)
|
||||
|
||||
Returns:
|
||||
list[Message]: A list containing the formatted message(s) for the observation.
|
||||
@@ -386,122 +372,6 @@ class ConversationMemory:
|
||||
elif isinstance(obs, AgentCondensationObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, SearchEngineObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif (
|
||||
isinstance(obs, RecallObservation)
|
||||
and self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
# everything is optional, check if they are present
|
||||
if obs.repo_name or obs.repo_directory:
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name=obs.repo_name or '',
|
||||
repo_directory=obs.repo_directory or '',
|
||||
)
|
||||
else:
|
||||
repo_info = None
|
||||
|
||||
if obs.runtime_hosts or obs.additional_agent_instructions:
|
||||
runtime_info = RuntimeInfo(
|
||||
available_hosts=obs.runtime_hosts,
|
||||
additional_agent_instructions=obs.additional_agent_instructions,
|
||||
)
|
||||
else:
|
||||
runtime_info = None
|
||||
|
||||
repo_instructions = (
|
||||
obs.repo_instructions if obs.repo_instructions else ''
|
||||
)
|
||||
|
||||
# Have some meaningful content before calling the template
|
||||
has_repo_info = repo_info is not None and (
|
||||
repo_info.repo_name or repo_info.repo_directory
|
||||
)
|
||||
has_runtime_info = runtime_info is not None and (
|
||||
runtime_info.available_hosts
|
||||
or runtime_info.additional_agent_instructions
|
||||
)
|
||||
has_repo_instructions = bool(repo_instructions.strip())
|
||||
|
||||
# Filter and process microagent knowledge
|
||||
filtered_agents = []
|
||||
if obs.microagent_knowledge:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in obs.microagent_knowledge
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
has_microagent_knowledge = bool(filtered_agents)
|
||||
|
||||
# Generate appropriate content based on what is present
|
||||
message_content = []
|
||||
|
||||
# Build the workspace context information
|
||||
if has_repo_info or has_runtime_info or has_repo_instructions:
|
||||
formatted_workspace_text = (
|
||||
self.prompt_manager.build_workspace_context(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
)
|
||||
)
|
||||
message_content.append(TextContent(text=formatted_workspace_text))
|
||||
|
||||
# Add microagent knowledge if present
|
||||
if has_microagent_knowledge:
|
||||
formatted_microagent_text = (
|
||||
self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
)
|
||||
message_content.append(TextContent(text=formatted_microagent_text))
|
||||
|
||||
# Return the combined message if we have any content
|
||||
if message_content:
|
||||
message = Message(role='user', content=message_content)
|
||||
else:
|
||||
return []
|
||||
elif obs.recall_type == RecallType.KNOWLEDGE:
|
||||
# Use prompt manager to build the microagent info
|
||||
# First, filter out agents that appear in earlier RecallObservations
|
||||
filtered_agents = self._filter_agents_in_microagent_obs(
|
||||
obs, current_index, events or []
|
||||
)
|
||||
|
||||
# Create and return a message if there is microagent knowledge to include
|
||||
if filtered_agents:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in filtered_agents
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
# Only proceed if we still have agents after filtering out disabled ones
|
||||
if filtered_agents:
|
||||
formatted_text = self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
|
||||
return [
|
||||
Message(
|
||||
role='user', content=[TextContent(text=formatted_text)]
|
||||
)
|
||||
]
|
||||
|
||||
# Return empty list if no microagents to include or all were disabled
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, RecallObservation)
|
||||
and not self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
# If prompt extensions are disabled, we don't add any additional info
|
||||
# TODO: test this
|
||||
return []
|
||||
else:
|
||||
# If an observation message is not returned, it will cause an error
|
||||
# when the LLM tries to return the next message
|
||||
@@ -534,51 +404,3 @@ class ConversationMemory:
|
||||
-1
|
||||
].cache_prompt = True # Last item inside the message content
|
||||
break
|
||||
|
||||
def _filter_agents_in_microagent_obs(
|
||||
self, obs: RecallObservation, current_index: int, events: list[Event]
|
||||
) -> list[MicroagentKnowledge]:
|
||||
"""Filter out agents that appear in earlier RecallObservations.
|
||||
|
||||
Args:
|
||||
obs: The current RecallObservation to filter
|
||||
current_index: The index of the current event in the events list
|
||||
events: The list of all events
|
||||
|
||||
Returns:
|
||||
list[MicroagentKnowledge]: The filtered list of microagent knowledge
|
||||
"""
|
||||
if obs.recall_type != RecallType.KNOWLEDGE:
|
||||
return obs.microagent_knowledge
|
||||
|
||||
# For each agent in the current microagent observation, check if it appears in any earlier microagent observation
|
||||
filtered_agents = []
|
||||
for agent in obs.microagent_knowledge:
|
||||
# Keep this agent if it doesn't appear in any earlier observation
|
||||
# that is, if this is the first microagent observation with this microagent
|
||||
if not self._has_agent_in_earlier_events(agent.name, current_index, events):
|
||||
filtered_agents.append(agent)
|
||||
|
||||
return filtered_agents
|
||||
|
||||
def _has_agent_in_earlier_events(
|
||||
self, agent_name: str, current_index: int, events: list[Event]
|
||||
) -> bool:
|
||||
"""Check if an agent appears in any earlier RecallObservation in the event list.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent to look for
|
||||
current_index: The index of the current event in the events list
|
||||
events: The list of all events
|
||||
|
||||
Returns:
|
||||
bool: True if the agent appears in an earlier RecallObservation, False otherwise
|
||||
"""
|
||||
for event in events[:current_index]:
|
||||
# Note that this check includes the WORKSPACE_CONTEXT
|
||||
if isinstance(event, RecallObservation):
|
||||
if any(
|
||||
agent.name == agent_name for agent in event.microagent_knowledge
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1,292 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
import openhands
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
KnowledgeMicroAgent,
|
||||
RepoMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
|
||||
|
||||
GLOBAL_MICROAGENTS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'microagents',
|
||||
)
|
||||
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Memory is a component that listens to the EventStream for information retrieval actions
|
||||
(a RecallAction) and publishes observations with the content (such as RecallObservation).
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
status_callback: Callable | None
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
sid: str,
|
||||
status_callback: Callable | None = None,
|
||||
):
|
||||
self.event_stream = event_stream
|
||||
self.sid = sid if sid else str(uuid.uuid4())
|
||||
self.status_callback = status_callback
|
||||
self.loop = None
|
||||
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY,
|
||||
self.on_event,
|
||||
self.sid,
|
||||
)
|
||||
|
||||
# Additional placeholders to store user workspace microagents
|
||||
self.repo_microagents: dict[str, RepoMicroAgent] = {}
|
||||
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
|
||||
|
||||
# Store repository / runtime info to send them to the templating later
|
||||
self.repository_info: RepositoryInfo | None = None
|
||||
self.runtime_info: RuntimeInfo | None = None
|
||||
|
||||
# Load global microagents (Knowledge + Repo)
|
||||
# from typically OpenHands/microagents (i.e., the PUBLIC microagents)
|
||||
self._load_global_microagents()
|
||||
|
||||
def on_event(self, event: Event):
|
||||
"""Handle an event from the event stream."""
|
||||
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
"""Handle an event from the event stream asynchronously."""
|
||||
try:
|
||||
if isinstance(event, RecallAction):
|
||||
# if this is a workspace context recall (on first user message)
|
||||
# create and add a RecallObservation
|
||||
# with info about repo, runtime, instructions, etc. including microagent knowledge if any
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and event.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
):
|
||||
logger.debug('Workspace context recall')
|
||||
workspace_obs: RecallObservation | NullObservation | None = None
|
||||
|
||||
workspace_obs = self._on_workspace_context_recall(event)
|
||||
if workspace_obs is None:
|
||||
workspace_obs = NullObservation(content='')
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
workspace_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
self.event_stream.add_event(workspace_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
|
||||
# Handle knowledge recall (triggered microagents)
|
||||
elif (
|
||||
event.source == EventSource.USER
|
||||
and event.recall_type == RecallType.KNOWLEDGE
|
||||
):
|
||||
logger.debug('Microagent knowledge recall')
|
||||
microagent_obs: RecallObservation | NullObservation | None = None
|
||||
microagent_obs = self._on_microagent_recall(event)
|
||||
if microagent_obs is None:
|
||||
microagent_obs = NullObservation(content='')
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
microagent_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
self.event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
except Exception as e:
|
||||
error_str = f'Error: {str(e.__class__.__name__)}'
|
||||
logger.error(error_str)
|
||||
self.send_error_message('STATUS$ERROR_MEMORY', error_str)
|
||||
return
|
||||
|
||||
def _on_workspace_context_recall(
|
||||
self, event: RecallAction
|
||||
) -> RecallObservation | None:
|
||||
"""Add repository and runtime information to the stream as a RecallObservation."""
|
||||
|
||||
# Create WORKSPACE_CONTEXT info:
|
||||
# - repository_info
|
||||
# - runtime_info
|
||||
# - repository_instructions
|
||||
# - microagent_knowledge
|
||||
|
||||
# Collect raw repository instructions
|
||||
repo_instructions = ''
|
||||
assert (
|
||||
len(self.repo_microagents) <= 1
|
||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
||||
|
||||
# Retrieve the context of repo instructions
|
||||
for microagent in self.repo_microagents.values():
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Create observation if we have anything
|
||||
if (
|
||||
self.repository_info
|
||||
or self.runtime_info
|
||||
or repo_instructions
|
||||
or microagent_knowledge
|
||||
):
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name=self.repository_info.repo_name
|
||||
if self.repository_info and self.repository_info.repo_name is not None
|
||||
else '',
|
||||
repo_directory=self.repository_info.repo_directory
|
||||
if self.repository_info
|
||||
and self.repository_info.repo_directory is not None
|
||||
else '',
|
||||
repo_instructions=repo_instructions if repo_instructions else '',
|
||||
runtime_hosts=self.runtime_info.available_hosts
|
||||
if self.runtime_info and self.runtime_info.available_hosts is not None
|
||||
else {},
|
||||
additional_agent_instructions=self.runtime_info.additional_agent_instructions
|
||||
if self.runtime_info
|
||||
and self.runtime_info.additional_agent_instructions is not None
|
||||
else '',
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Added workspace context',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _on_microagent_recall(
|
||||
self,
|
||||
event: RecallAction,
|
||||
) -> RecallObservation | None:
|
||||
"""When a microagent action triggers microagents, create a RecallObservation with structured data."""
|
||||
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Create observation if we have anything
|
||||
if microagent_knowledge:
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
|
||||
"""Find microagent knowledge based on a query.
|
||||
|
||||
Args:
|
||||
query: The query to search for microagent triggers
|
||||
|
||||
Returns:
|
||||
A list of MicroagentKnowledge objects for matched triggers
|
||||
"""
|
||||
recalled_content: list[MicroagentKnowledge] = []
|
||||
|
||||
# skip empty queries
|
||||
if not query:
|
||||
return recalled_content
|
||||
|
||||
# Search for microagent triggers in the query
|
||||
for name, microagent in self.knowledge_microagents.items():
|
||||
trigger = microagent.match_trigger(query)
|
||||
if trigger:
|
||||
logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
|
||||
recalled_content.append(
|
||||
MicroagentKnowledge(
|
||||
name=microagent.name,
|
||||
trigger=trigger,
|
||||
content=microagent.content,
|
||||
)
|
||||
)
|
||||
return recalled_content
|
||||
|
||||
def load_user_workspace_microagents(
|
||||
self, user_microagents: list[BaseMicroAgent]
|
||||
) -> None:
|
||||
"""
|
||||
This method loads microagents from a user's cloned repo or workspace directory.
|
||||
|
||||
This is typically called from agent_session or setup once the workspace is cloned.
|
||||
"""
|
||||
logger.info(
|
||||
'Loading user workspace microagents: %s', [m.name for m in user_microagents]
|
||||
)
|
||||
for user_microagent in user_microagents:
|
||||
if isinstance(user_microagent, KnowledgeMicroAgent):
|
||||
self.knowledge_microagents[user_microagent.name] = user_microagent
|
||||
elif isinstance(user_microagent, RepoMicroAgent):
|
||||
self.repo_microagents[user_microagent.name] = user_microagent
|
||||
|
||||
def _load_global_microagents(self) -> None:
|
||||
"""
|
||||
Loads microagents from the global microagents_dir
|
||||
"""
|
||||
repo_agents, knowledge_agents, _ = load_microagents_from_dir(
|
||||
GLOBAL_MICROAGENTS_DIR
|
||||
)
|
||||
for name, agent in knowledge_agents.items():
|
||||
if isinstance(agent, KnowledgeMicroAgent):
|
||||
self.knowledge_microagents[name] = agent
|
||||
for name, agent in repo_agents.items():
|
||||
if isinstance(agent, RepoMicroAgent):
|
||||
self.repo_microagents[name] = agent
|
||||
|
||||
def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
|
||||
"""Store repository info so we can reference it in an observation."""
|
||||
if repo_name or repo_directory:
|
||||
self.repository_info = RepositoryInfo(repo_name, repo_directory)
|
||||
else:
|
||||
self.repository_info = None
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime) -> None:
|
||||
"""Store runtime info (web hosts, ports, etc.)."""
|
||||
# e.g. { '127.0.0.1': 8080 }
|
||||
if runtime.web_hosts or runtime.additional_agent_instructions:
|
||||
self.runtime_info = RuntimeInfo(
|
||||
available_hosts=runtime.web_hosts,
|
||||
additional_agent_instructions=runtime.additional_agent_instructions,
|
||||
)
|
||||
else:
|
||||
self.runtime_info = None
|
||||
|
||||
def send_error_message(self, message_id: str, message: str):
|
||||
"""Sends an error message if the callback function was provided."""
|
||||
if self.status_callback:
|
||||
try:
|
||||
if self.loop is None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_status_message('error', message_id, message), self.loop
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(
|
||||
f'Error sending status message: {e.__class__.__name__}',
|
||||
stack_info=False,
|
||||
)
|
||||
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str):
|
||||
"""Sends a status message to the client."""
|
||||
if self.status_callback:
|
||||
self.status_callback(msg_type, id, message)
|
||||
@@ -41,7 +41,6 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
SearchAction,
|
||||
)
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.observation import (
|
||||
@@ -57,7 +56,6 @@ from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.runtime.browser import browse
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
from openhands.runtime.plugins import ALL_PLUGINS, JupyterPlugin, Plugin, VSCodePlugin
|
||||
from openhands.runtime.search_engine.brave_search import search
|
||||
from openhands.runtime.utils.bash import BashSession
|
||||
from openhands.runtime.utils.files import insert_lines, read_lines
|
||||
from openhands.runtime.utils.memory_monitor import MemoryMonitor
|
||||
@@ -165,6 +163,7 @@ class ActionExecutor:
|
||||
self.start_time = time.time()
|
||||
self.last_execution_time = self.start_time
|
||||
self._initialized = False
|
||||
|
||||
self.max_memory_gb: int | None = None
|
||||
if _override_max_memory_gb := os.environ.get('RUNTIME_MAX_MEMORY_GB', None):
|
||||
self.max_memory_gb = int(_override_max_memory_gb)
|
||||
@@ -465,10 +464,6 @@ class ActionExecutor:
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return await browse(action, self.browser)
|
||||
|
||||
async def search(self, action: SearchAction) -> Observation:
|
||||
obs = await call_sync_from_async(search, action)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
self.memory_monitor.stop_monitoring()
|
||||
if self.bash_session is not None:
|
||||
|
||||
@@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = False,
|
||||
user_id: str | None = None,
|
||||
github_user_id: str | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
@@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
|
||||
)
|
||||
|
||||
self.user_id = user_id
|
||||
self.github_user_id = github_user_id
|
||||
|
||||
def setup_initial_env(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
@@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
|
||||
assert event.timeout is not None
|
||||
try:
|
||||
if isinstance(event, CmdRunAction):
|
||||
if self.user_id and '$GITHUB_TOKEN' in event.command:
|
||||
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
|
||||
gh_client = GithubServiceImpl(
|
||||
external_auth_id=self.user_id, external_token_manager=True
|
||||
user_id=self.github_user_id, external_token_manager=True
|
||||
)
|
||||
token = await gh_client.get_latest_token()
|
||||
if token:
|
||||
|
||||
@@ -24,7 +24,6 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
SearchAction,
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.files import FileEditSource
|
||||
@@ -60,7 +59,7 @@ class ActionExecutionClient(Runtime):
|
||||
status_callback: Any | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
github_user_id: str | None = None,
|
||||
):
|
||||
self.session = HttpSession()
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
@@ -76,7 +75,7 @@ class ActionExecutionClient(Runtime):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
github_user_id,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -298,9 +297,6 @@ class ActionExecutionClient(Runtime):
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return self.send_action_for_execution(action)
|
||||
|
||||
def search(self, action: SearchAction) -> Observation:
|
||||
return self.send_action_for_execution(action)
|
||||
|
||||
def close(self) -> None:
|
||||
# Make sure we don't close the session multiple times
|
||||
# Can happen in evaluation
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable
|
||||
from urllib.parse import urlparse
|
||||
@@ -46,7 +45,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
github_user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
@@ -57,7 +56,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
github_user_id,
|
||||
)
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
@@ -426,11 +425,10 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
return self._send_action_server_request_impl(method, url, **kwargs)
|
||||
|
||||
retry_decorator = tenacity.retry(
|
||||
retry=tenacity.retry_if_exception_type(requests.ConnectionError),
|
||||
retry=tenacity.retry_if_exception_type(ConnectionError),
|
||||
stop=tenacity.stop_after_attempt(3)
|
||||
| stop_if_should_exit()
|
||||
| self._stop_if_closed,
|
||||
before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
return retry_decorator(self._send_action_server_request_impl)(
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from openhands.runtime.search_engine.brave_search import search
|
||||
|
||||
__all__ = ['search']
|
||||
@@ -1,239 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import SearchAction
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.search_engine import SearchEngineObservation
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
def get_title(result):
|
||||
return f"### Title: {result['title']}\n" if 'title' in result else ''
|
||||
|
||||
|
||||
def get_url(result):
|
||||
return f"### URL: {result['url']}\n" if 'url' in result else ''
|
||||
|
||||
|
||||
def get_description(result):
|
||||
return (
|
||||
f"### Description: {result['description']}\n" if 'description' in result else ''
|
||||
)
|
||||
|
||||
|
||||
def get_question(result):
|
||||
return f"### Question: {result['question']}\n" if 'question' in result else ''
|
||||
|
||||
|
||||
def get_answer(result):
|
||||
return f"### Answer: {result['answer']}\n" if 'answer' in result else ''
|
||||
|
||||
|
||||
def get_cluster(result):
|
||||
if 'cluster' in result:
|
||||
output = ''
|
||||
for i, result_obj in enumerate(result['cluster']):
|
||||
title = get_title(result_obj)
|
||||
url = get_url(result_obj)
|
||||
description = get_description(result_obj)
|
||||
discussion_output = (
|
||||
f'### Related webpage\n#{title}#{url}#{description}\n'
|
||||
if url != ''
|
||||
else ''
|
||||
)
|
||||
output += discussion_output
|
||||
return output
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def response_to_markdown(results, query):
|
||||
all_results = {}
|
||||
|
||||
# discussions
|
||||
discussion_results = []
|
||||
if 'discussions' in results and 'results' in results['discussions']['results']:
|
||||
for result in results['discussions']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
cluster = get_cluster(result)
|
||||
discussion_output = f'## Discussion\n{title}{url}{description}{cluster}\n'
|
||||
discussion_results.append(discussion_output)
|
||||
all_results['discussions'] = discussion_results
|
||||
|
||||
# FAQs
|
||||
faq_results = []
|
||||
if 'faq' in results and 'results' in results['faq']:
|
||||
for result in results['faq']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
question = get_question(result)
|
||||
answer = get_answer(result)
|
||||
faq_output = f'## FAQ\n{title}{url}{question}{answer}\n'
|
||||
faq_results.append(faq_output)
|
||||
all_results['faq'] = faq_results
|
||||
|
||||
# News
|
||||
news_results = []
|
||||
if 'news' in results and 'results' in results['news']:
|
||||
for result in results['news']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
news_output = f'## News\n{title}{url}{description}\n'
|
||||
news_results.append(news_output)
|
||||
all_results['news'] = news_results
|
||||
|
||||
# Videos
|
||||
video_results = []
|
||||
if 'videos' in results and 'results' in results['videos']:
|
||||
for result in results['videos']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
video_output = f'## Video\n{title}{url}{description}\n'
|
||||
video_results.append(video_output)
|
||||
all_results['videos'] = video_results
|
||||
|
||||
# Web Search Results
|
||||
websearch_results = []
|
||||
if 'web' in results and 'results' in results['web']:
|
||||
for result in results['web']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
cluster = get_cluster(result)
|
||||
if cluster:
|
||||
websearch_output = f'## Webpage\n{title}{url}{description}\n{cluster}\n'
|
||||
else:
|
||||
websearch_output = f'## Webpage\n{title}{url}{description}\n'
|
||||
websearch_results.append(websearch_output)
|
||||
all_results['web'] = websearch_results
|
||||
|
||||
# infobox
|
||||
infobox_results = []
|
||||
if 'infobox' in results and 'results' in results['infobox']:
|
||||
for result in results['infobox']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
infobox_output = f'## Infobox\n{title}{url}{description}\n'
|
||||
infobox_results.append(infobox_output)
|
||||
all_results['infobox'] = infobox_results
|
||||
|
||||
# locations
|
||||
location_results = []
|
||||
if 'locations' in results and 'results' in results['location']:
|
||||
for result in results['locations']['results']:
|
||||
title = get_title(result)
|
||||
url = get_url(result)
|
||||
description = get_description(result)
|
||||
location_output = f'## Location\n{title}{url}{description}\n'
|
||||
location_results.append(location_output)
|
||||
all_results['locations'] = location_results
|
||||
|
||||
markdown = '# Search Results\n\n'
|
||||
markdown += f'**Searched query**: {query}\n\n'
|
||||
|
||||
# ranked results if available
|
||||
if 'mixed' in results:
|
||||
for rank_type in ['main', 'top', 'side']:
|
||||
if rank_type not in results['mixed']:
|
||||
continue
|
||||
for ranked_result in results['mixed'][rank_type]:
|
||||
result_type = ranked_result['type']
|
||||
if result_type in all_results:
|
||||
include_all = ranked_result['all']
|
||||
idx = ranked_result.get('index', None)
|
||||
if include_all:
|
||||
markdown += ''.join(all_results[result_type])
|
||||
elif idx is not None and idx < len(all_results[result_type]):
|
||||
markdown += all_results[result_type][idx]
|
||||
for result_list in all_results.values():
|
||||
for result in result_list:
|
||||
if result in markdown:
|
||||
continue
|
||||
else:
|
||||
markdown += result
|
||||
else:
|
||||
markdown += ''.join(
|
||||
websearch_results
|
||||
+ video_results
|
||||
+ news_results
|
||||
+ infobox_results
|
||||
+ faq_results
|
||||
+ discussion_results
|
||||
+ location_results
|
||||
)
|
||||
return markdown
|
||||
|
||||
|
||||
def return_error(retry_state: tenacity.RetryCallState):
|
||||
return ErrorObservation('Failed to query Brave Search API.')
|
||||
|
||||
|
||||
@tenacity.retry(
|
||||
wait=tenacity.wait_exponential(min=2, max=10),
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
retry_error_callback=return_error,
|
||||
)
|
||||
def query_api(query: str, API_KEY, BRAVE_SEARCH_URL):
|
||||
headers = {'Accept': 'application/json', 'X-Subscription-Token': API_KEY}
|
||||
|
||||
params: list[tuple[str, str | int | bool]] = [
|
||||
('q', query),
|
||||
('count', 20), # Number of results to return, max allowed = 20
|
||||
('extra_snippets', False), # TODO: Should we keep it as true?
|
||||
]
|
||||
|
||||
response = requests.get(
|
||||
BRAVE_SEARCH_URL,
|
||||
headers=headers,
|
||||
params=params, # type: ignore
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status() # Raise exception for 4XX/5XX responses
|
||||
results = response.json()
|
||||
markdown_content = response_to_markdown(results, query)
|
||||
# TODO: Handle other types of HTML tags? I couldn't find any other tags in brave search responses for the queries I tried.
|
||||
markdown_content = re.sub(r'</?strong>', '', markdown_content)
|
||||
return SearchEngineObservation(query=query, content=markdown_content)
|
||||
|
||||
|
||||
def search(action: SearchAction, config: AppConfig):
|
||||
"""Execute a search query using the Brave Search API.
|
||||
|
||||
Args:
|
||||
action: The search action containing the query.
|
||||
config: The application configuration.
|
||||
|
||||
Returns:
|
||||
SearchEngineObservation: The search results in markdown format.
|
||||
ErrorObservation: If the query is empty or search is not enabled.
|
||||
"""
|
||||
if not config.search.enabled:
|
||||
return ErrorObservation(
|
||||
content='Search engine functionality is not enabled. Enable it by setting search.enabled=true in config.'
|
||||
)
|
||||
|
||||
query = action.query
|
||||
if query is None or len(query.strip()) == 0:
|
||||
return ErrorObservation(
|
||||
content='The query string for search_engine tool must be a non-empty string.'
|
||||
)
|
||||
|
||||
if config.search.api_key is None:
|
||||
return ErrorObservation(
|
||||
content='Search API key not configured. Set search.api_key in config.'
|
||||
)
|
||||
|
||||
return query_api(
|
||||
query=query,
|
||||
API_KEY=config.search.api_key.get_secret_value(),
|
||||
BRAVE_SEARCH_URL=config.search.api_url
|
||||
)
|
||||
@@ -46,12 +46,7 @@ class ConversationManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
||||
) -> EventStream | None:
|
||||
"""Join a conversation and return its event stream."""
|
||||
|
||||
@@ -79,7 +74,6 @@ class ConversationManager(ABC):
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
github_user_id: str | None = None,
|
||||
) -> EventStream:
|
||||
"""Start an event loop if one is not already running"""
|
||||
|
||||
|
||||
@@ -106,12 +106,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
return c
|
||||
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
||||
):
|
||||
logger.info(
|
||||
f'join_conversation:{sid}:{connection_id}',
|
||||
@@ -121,9 +116,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
return await self.maybe_start_agent_loop(
|
||||
sid, settings, user_id, github_user_id=github_user_id
|
||||
)
|
||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
for event in event_stream.get_events(reverse=True):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in (
|
||||
@@ -194,18 +187,14 @@ class StandaloneConversationManager(ConversationManager):
|
||||
logger.error('error_cleaning_stale')
|
||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||
|
||||
async def _get_conversation_store(
|
||||
self, user_id: str | None, github_user_id: str | None
|
||||
) -> ConversationStore:
|
||||
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
|
||||
conversation_store_class = self._conversation_store_class
|
||||
if not conversation_store_class:
|
||||
self._conversation_store_class = conversation_store_class = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
self.server_config.conversation_store_class,
|
||||
)
|
||||
store = await conversation_store_class.get_instance(
|
||||
self.config, user_id, github_user_id
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
return store
|
||||
|
||||
async def get_running_agent_loops(
|
||||
@@ -254,7 +243,6 @@ class StandaloneConversationManager(ConversationManager):
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
github_user_id: str | None = None,
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
|
||||
session: Session | None = None
|
||||
@@ -268,9 +256,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
conversation_store = await self._get_conversation_store(
|
||||
user_id, github_user_id
|
||||
)
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversations = await conversation_store.get_all_metadata(response_ids)
|
||||
conversations.sort(key=_last_updated_at_key, reverse=True)
|
||||
|
||||
@@ -291,9 +277,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(
|
||||
user_id, github_user_id, sid
|
||||
),
|
||||
self._create_conversation_update_callback(user_id, sid),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@@ -390,23 +374,22 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||
self, user_id: str | None, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
self._update_timestamp_for_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
github_user_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
async def _update_timestamp_for_conversation(
|
||||
self, user_id: str, github_user_id: str, conversation_id: str
|
||||
self, user_id: str, conversation_id: str
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now(timezone.utc)
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
@@ -6,14 +6,10 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.observation import (
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
AgentStateChangedObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.shared import (
|
||||
@@ -39,9 +35,7 @@ async def connect(connection_id: str, environ):
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
conversation_validator = ConversationValidatorImpl()
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str
|
||||
)
|
||||
user_id = await conversation_validator.validate(conversation_id, cookies_str)
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
@@ -52,7 +46,7 @@ async def connect(connection_id: str, environ):
|
||||
)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, settings, user_id, github_user_id
|
||||
conversation_id, connection_id, settings, user_id
|
||||
)
|
||||
|
||||
agent_state_changed = None
|
||||
@@ -60,7 +54,10 @@ async def connect(connection_id: str, environ):
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction, RecallObservation),
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
|
||||
@@ -10,12 +10,7 @@ from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import (
|
||||
get_access_token,
|
||||
get_github_user_id,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -78,12 +73,12 @@ async def _create_new_conversation(
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['provider_token'] = token
|
||||
session_init_args['github_token'] = token or SecretStr('')
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@@ -105,8 +100,7 @@ async def _create_new_conversation(
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
github_user_id=None,
|
||||
github_user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
@@ -128,10 +122,7 @@ async def _create_new_conversation(
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
initial_user_msg=initial_message_action,
|
||||
conversation_id, conversation_init_data, user_id, initial_message_action
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
@@ -167,7 +158,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
get_user_id(request),
|
||||
user_id,
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
@@ -206,7 +197,7 @@ async def search_conversations(
|
||||
limit: int = 20,
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
|
||||
@@ -225,7 +216,7 @@ async def search_conversations(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
get_user_id(request), set(conversation_ids)
|
||||
get_github_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
@@ -245,7 +236,7 @@ async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
@@ -261,7 +252,7 @@ async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
@@ -277,7 +268,7 @@ async def delete_conversation(
|
||||
request: Request,
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
config, get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
@@ -90,38 +90,30 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
if settings.unset_github_token:
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
|
||||
# Merge provider tokens with existing ones
|
||||
if settings.unset_github_token: # Only merge if not unsetting tokens
|
||||
settings.secrets_store.provider_tokens = {}
|
||||
settings.provider_tokens = {}
|
||||
else: # Only merge if not unsetting tokens
|
||||
if settings.provider_tokens:
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@@ -15,8 +15,7 @@ from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
@@ -53,7 +52,7 @@ class AgentSession:
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
status_callback: Callable | None = None,
|
||||
user_id: str | None = None,
|
||||
github_user_id: str | None = None,
|
||||
):
|
||||
"""Initializes a new instance of the Session class
|
||||
|
||||
@@ -66,9 +65,9 @@ class AgentSession:
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.file_store = file_store
|
||||
self._status_callback = status_callback
|
||||
self.user_id = user_id
|
||||
self.github_user_id = github_user_id
|
||||
self.logger = OpenHandsLoggerAdapter(
|
||||
extra={'session_id': sid, 'user_id': user_id}
|
||||
extra={'session_id': sid, 'user_id': github_user_id}
|
||||
)
|
||||
|
||||
async def start(
|
||||
@@ -127,15 +126,6 @@ class AgentSession:
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
)
|
||||
|
||||
repo_directory = None
|
||||
if self.runtime and runtime_connected and selected_repository:
|
||||
repo_directory = selected_repository.split('/')[-1]
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
if github_token:
|
||||
self.event_stream.set_secrets(
|
||||
{
|
||||
@@ -241,7 +231,7 @@ class AgentSession:
|
||||
|
||||
kwargs = {}
|
||||
if runtime_cls == RemoteRuntime:
|
||||
kwargs['user_id'] = self.user_id
|
||||
kwargs['github_user_id'] = self.github_user_id
|
||||
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
@@ -270,14 +260,26 @@ class AgentSession:
|
||||
)
|
||||
return False
|
||||
|
||||
repo_directory = None
|
||||
if selected_repository:
|
||||
await call_sync_from_async(
|
||||
repo_directory = await call_sync_from_async(
|
||||
self.runtime.clone_repo,
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
)
|
||||
|
||||
if agent.prompt_manager:
|
||||
agent.prompt_manager.set_runtime_info(self.runtime)
|
||||
microagents: list[BaseMicroAgent] = await call_sync_from_async(
|
||||
self.runtime.get_microagents_from_selected_repo, selected_repository
|
||||
)
|
||||
agent.prompt_manager.load_microagents(microagents)
|
||||
if selected_repository and repo_directory:
|
||||
agent.prompt_manager.set_repository_info(
|
||||
selected_repository, repo_directory
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
)
|
||||
@@ -340,29 +342,6 @@ class AgentSession:
|
||||
|
||||
return controller
|
||||
|
||||
async def _create_memory(
|
||||
self, selected_repository: str | None, repo_directory: str | None
|
||||
) -> Memory:
|
||||
memory = Memory(
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
status_callback=self._status_callback,
|
||||
)
|
||||
|
||||
if self.runtime:
|
||||
# sets available hosts and other runtime info
|
||||
memory.set_runtime_info(self.runtime)
|
||||
|
||||
# loads microagents from repo/.openhands/microagents
|
||||
microagents: list[BaseMicroAgent] = await call_sync_from_async(
|
||||
self.runtime.get_microagents_from_selected_repo, selected_repository
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
|
||||
if selected_repository and repo_directory:
|
||||
memory.set_repository_info(selected_repository, repo_directory)
|
||||
return memory
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
"""Helper method to handle state restore logic."""
|
||||
restored_state = None
|
||||
|
||||
@@ -8,6 +8,6 @@ class ConversationInitData(Settings):
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
|
||||
provider_token: SecretStr | None = Field(default=None)
|
||||
github_token: SecretStr | None = Field(default=None)
|
||||
selected_repository: str | None = Field(default=None)
|
||||
selected_branch: str | None = Field(default=None)
|
||||
|
||||
@@ -61,7 +61,7 @@ class Session:
|
||||
sid,
|
||||
file_store,
|
||||
status_callback=self.queue_status_message,
|
||||
user_id=user_id,
|
||||
github_user_id=user_id,
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
@@ -123,11 +123,11 @@ class Session:
|
||||
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
|
||||
provider_token = None
|
||||
github_token = None
|
||||
selected_repository = None
|
||||
selected_branch = None
|
||||
if isinstance(settings, ConversationInitData):
|
||||
provider_token = settings.provider_token
|
||||
github_token = settings.github_token
|
||||
selected_repository = settings.selected_repository
|
||||
selected_branch = settings.selected_branch
|
||||
|
||||
@@ -140,7 +140,7 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
github_token=provider_token,
|
||||
github_token=github_token,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
initial_message=initial_message,
|
||||
|
||||
@@ -43,7 +43,7 @@ class Settings(BaseModel):
|
||||
if context and context.get('expose_secrets', False):
|
||||
return llm_api_key.get_secret_value()
|
||||
|
||||
return pydantic_encoder(llm_api_key) if llm_api_key else None
|
||||
return pydantic_encoder(llm_api_key)
|
||||
|
||||
@staticmethod
|
||||
def _convert_token_value(
|
||||
|
||||
@@ -12,36 +12,25 @@ from openhands.utils.async_utils import wait_all
|
||||
|
||||
|
||||
class ConversationStore(ABC):
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
"""
|
||||
Storage for conversation metadata. May or may not support multiple users depending on the environment
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def save_metadata(self, metadata: ConversationMetadata) -> None:
|
||||
"""Store conversation metadata."""
|
||||
"""Store conversation metadata"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
||||
"""Load conversation metadata."""
|
||||
|
||||
async def validate_metadata(
|
||||
self, conversation_id: str, user_id: str, github_user_id: str
|
||||
) -> bool:
|
||||
"""Validate that conversation belongs to the current user."""
|
||||
# TODO: remove github_user_id after transition to Keycloak is complete.
|
||||
metadata = await self.get_metadata(conversation_id)
|
||||
if (not metadata.user_id and not metadata.github_user_id) or (
|
||||
metadata.user_id != user_id and metadata.github_user_id != github_user_id
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
"""Load conversation metadata"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
"""Delete conversation metadata."""
|
||||
"""delete conversation metadata"""
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, conversation_id: str) -> bool:
|
||||
"""Check if conversation exists."""
|
||||
"""Check if conversation exists"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
@@ -60,6 +49,6 @@ class ConversationStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@@ -7,7 +7,7 @@ class ConversationValidator:
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def validate(self, conversation_id: str, cookies_str: str):
|
||||
return None, None
|
||||
return None
|
||||
|
||||
|
||||
conversation_validator_cls = os.environ.get(
|
||||
|
||||
@@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore):
|
||||
try:
|
||||
conversations.append(await self.get_metadata(conversation_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'Could not load conversation metadata: {conversation_id}',
|
||||
logger.error(
|
||||
f'Error loading conversation: {conversation_id}',
|
||||
)
|
||||
conversations.sort(key=_sort_key, reverse=True)
|
||||
conversations = conversations[start:end]
|
||||
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime, timezone
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
conversation_id: str
|
||||
user_id: str | None
|
||||
github_user_id: str | None
|
||||
selected_repository: str | None
|
||||
selected_branch: str | None = None
|
||||
|
||||
@@ -1,18 +1,25 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
KnowledgeMicroAgent,
|
||||
RepoMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeInfo:
|
||||
available_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
available_hosts: dict[str, int]
|
||||
additional_agent_instructions: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -25,23 +32,75 @@ class RepositoryInfo:
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents.
|
||||
Manages prompt templates and micro-agents for AI interactions.
|
||||
|
||||
This class is dedicated to loading and rendering prompts (system prompt, user prompt).
|
||||
This class handles loading and rendering of system and user prompt templates,
|
||||
as well as loading micro-agent specifications. It provides methods to access
|
||||
rendered system and initial user messages for AI interactions.
|
||||
|
||||
Attributes:
|
||||
prompt_dir: Directory containing prompt templates.
|
||||
prompt_dir (str): Directory containing prompt templates.
|
||||
microagent_dir (str): Directory containing microagent specifications.
|
||||
disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_dir: str,
|
||||
microagent_dir: str | None = None,
|
||||
disabled_microagents: list[str] | None = None,
|
||||
):
|
||||
self.disabled_microagents: list[str] = disabled_microagents or []
|
||||
self.prompt_dir: str = prompt_dir
|
||||
self.repository_info: RepositoryInfo | None = None
|
||||
self.system_template: Template = self._load_template('system_prompt')
|
||||
self.user_template: Template = self._load_template('user_prompt')
|
||||
self.additional_info_template: Template = self._load_template('additional_info')
|
||||
self.microagent_info_template: Template = self._load_template('microagent_info')
|
||||
self.runtime_info = RuntimeInfo(
|
||||
available_hosts={}, additional_agent_instructions=''
|
||||
)
|
||||
|
||||
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
|
||||
self.repo_microagents: dict[str, RepoMicroAgent] = {}
|
||||
|
||||
if microagent_dir:
|
||||
# This loads micro-agents from the microagent_dir
|
||||
# which is typically the OpenHands/microagents (i.e., the PUBLIC microagents)
|
||||
|
||||
# Only load KnowledgeMicroAgents
|
||||
repo_microagents, knowledge_microagents, _ = load_microagents_from_dir(
|
||||
microagent_dir
|
||||
)
|
||||
assert all(
|
||||
isinstance(microagent, KnowledgeMicroAgent)
|
||||
for microagent in knowledge_microagents.values()
|
||||
)
|
||||
for name, microagent in knowledge_microagents.items():
|
||||
if name not in self.disabled_microagents:
|
||||
self.knowledge_microagents[name] = microagent
|
||||
assert all(
|
||||
isinstance(microagent, RepoMicroAgent)
|
||||
for microagent in repo_microagents.values()
|
||||
)
|
||||
for name, microagent in repo_microagents.items():
|
||||
if name not in self.disabled_microagents:
|
||||
self.repo_microagents[name] = microagent
|
||||
|
||||
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None:
|
||||
"""Load microagents from a list of BaseMicroAgents.
|
||||
|
||||
This is typically used when loading microagents from inside a repo.
|
||||
"""
|
||||
openhands_logger.info('Loading microagents: %s', [m.name for m in microagents])
|
||||
# Only keep KnowledgeMicroAgents and RepoMicroAgents
|
||||
for microagent in microagents:
|
||||
if microagent.name in self.disabled_microagents:
|
||||
continue
|
||||
if isinstance(microagent, KnowledgeMicroAgent):
|
||||
self.knowledge_microagents[microagent.name] = microagent
|
||||
elif isinstance(microagent, RepoMicroAgent):
|
||||
self.repo_microagents[microagent.name] = microagent
|
||||
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
if self.prompt_dir is None:
|
||||
@@ -55,6 +114,27 @@ class PromptManager:
|
||||
def get_system_message(self) -> str:
|
||||
return self.system_template.render().strip()
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime) -> None:
|
||||
self.runtime_info.available_hosts = runtime.web_hosts
|
||||
self.runtime_info.additional_agent_instructions = (
|
||||
runtime.additional_agent_instructions
|
||||
)
|
||||
|
||||
def set_repository_info(
|
||||
self,
|
||||
repo_name: str,
|
||||
repo_directory: str,
|
||||
) -> None:
|
||||
"""Sets information about the GitHub repository that has been cloned.
|
||||
|
||||
Args:
|
||||
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
|
||||
repo_directory: The directory where the repository has been cloned
|
||||
"""
|
||||
self.repository_info = RepositoryInfo(
|
||||
repo_name=repo_name, repo_directory=repo_directory
|
||||
)
|
||||
|
||||
def get_example_user_message(self) -> str:
|
||||
"""This is the initial user message provided to the agent
|
||||
before *actual* user instructions are provided.
|
||||
@@ -68,6 +148,45 @@ class PromptManager:
|
||||
|
||||
return self.user_template.render().strip()
|
||||
|
||||
def enhance_message(self, message: Message) -> None:
|
||||
"""Enhance the user message with additional context.
|
||||
|
||||
This method is used to enhance the user message with additional context
|
||||
about the user's task. The additional context will convert the current
|
||||
generic agent into a more specialized agent that is tailored to the user's task.
|
||||
"""
|
||||
if not message.content:
|
||||
return
|
||||
|
||||
# if there were other texts included, they were before the user message
|
||||
# so the last TextContent is the user message
|
||||
# content can be a list of TextContent or ImageContent
|
||||
message_content = ''
|
||||
for content in reversed(message.content):
|
||||
if isinstance(content, TextContent):
|
||||
message_content = content.text
|
||||
break
|
||||
|
||||
if not message_content:
|
||||
return
|
||||
|
||||
triggered_agents = []
|
||||
for name, microagent in self.knowledge_microagents.items():
|
||||
trigger = microagent.match_trigger(message_content)
|
||||
if trigger:
|
||||
openhands_logger.info(
|
||||
"Microagent '%s' triggered by keyword '%s'",
|
||||
name,
|
||||
trigger,
|
||||
)
|
||||
# Create a dictionary with the agent and trigger word
|
||||
triggered_agents.append({'agent': microagent, 'trigger_word': trigger})
|
||||
|
||||
if triggered_agents:
|
||||
formatted_text = self.build_microagent_info(triggered_agents)
|
||||
# Insert the new content at the start of the TextContent list
|
||||
message.content.insert(0, TextContent(text=formatted_text))
|
||||
|
||||
def add_examples_to_initial_message(self, message: Message) -> None:
|
||||
"""Add example_message to the first user message."""
|
||||
example_message = self.get_example_user_message() or None
|
||||
@@ -76,28 +195,44 @@ class PromptManager:
|
||||
if example_message:
|
||||
message.content.insert(0, TextContent(text=example_message))
|
||||
|
||||
def build_workspace_context(
|
||||
def add_info_to_initial_message(
|
||||
self,
|
||||
repository_info: RepositoryInfo | None,
|
||||
runtime_info: RuntimeInfo | None,
|
||||
repo_instructions: str = '',
|
||||
) -> str:
|
||||
"""Renders the additional info template with the stored repository/runtime info."""
|
||||
return self.additional_info_template.render(
|
||||
repository_info=repository_info,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""Adds information about the repository and runtime to the initial user message.
|
||||
|
||||
Args:
|
||||
message: The initial user message to add information to.
|
||||
"""
|
||||
repo_instructions = ''
|
||||
assert (
|
||||
len(self.repo_microagents) <= 1
|
||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
||||
for microagent in self.repo_microagents.values():
|
||||
# We assume these are the repo instructions
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
additional_info = self.additional_info_template.render(
|
||||
repository_instructions=repo_instructions,
|
||||
runtime_info=runtime_info,
|
||||
repository_info=self.repository_info,
|
||||
runtime_info=self.runtime_info,
|
||||
).strip()
|
||||
|
||||
# Insert the new content at the start of the TextContent list
|
||||
if additional_info:
|
||||
message.content.insert(0, TextContent(text=additional_info))
|
||||
|
||||
def build_microagent_info(
|
||||
self,
|
||||
triggered_agents: list[MicroagentKnowledge],
|
||||
triggered_agents: list[dict],
|
||||
) -> str:
|
||||
"""Renders the microagent info template with the triggered agents.
|
||||
|
||||
Args:
|
||||
triggered_agents: A list of MicroagentKnowledge objects containing information
|
||||
about triggered microagents.
|
||||
triggered_agents: A list of dictionaries, each containing an "agent"
|
||||
(KnowledgeMicroAgent) and a "trigger_word" (str).
|
||||
"""
|
||||
return self.microagent_info_template.render(
|
||||
triggered_agents=triggered_agents
|
||||
|
||||
74
poetry.lock
generated
74
poetry.lock
generated
@@ -496,18 +496,18 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.37.12"
|
||||
version = "1.37.11"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "boto3-1.37.12-py3-none-any.whl", hash = "sha256:516feaa0d2afaeda1515216fd09291368a1215754bbccb0f28414c0a91a830a2"},
|
||||
{file = "boto3-1.37.12.tar.gz", hash = "sha256:9412d404f103ad6d14f033eb29cd5e0cdca2b9b08cbfa9d4dabd1d7be2de2625"},
|
||||
{file = "boto3-1.37.11-py3-none-any.whl", hash = "sha256:da6c22fc8a7e9bca5d7fc465a877ac3d45b6b086d776bd1a6c55bdde60523741"},
|
||||
{file = "boto3-1.37.11.tar.gz", hash = "sha256:8eec08363ef5db05c2fbf58e89f0c0de6276cda2fdce01e76b3b5f423cd5c0f4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.37.12,<1.38.0"
|
||||
botocore = ">=1.37.11,<1.38.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.11.0,<0.12.0"
|
||||
|
||||
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.37.12"
|
||||
version = "1.37.11"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "botocore-1.37.12-py3-none-any.whl", hash = "sha256:ba1948c883bbabe20d95ff62c3e36954c9269686f7db9361857835677ca3e676"},
|
||||
{file = "botocore-1.37.12.tar.gz", hash = "sha256:ae2d5328ce6ad02eb615270507235a6e90fd3eeed615a6c0732b5a68b12f2017"},
|
||||
{file = "botocore-1.37.11-py3-none-any.whl", hash = "sha256:02505309b1235f9f15a6da79103ca224b3f3dc5f6a62f8630fbb2c6ed05e2da8"},
|
||||
{file = "botocore-1.37.11.tar.gz", hash = "sha256:72eb3a9a58b064be26ba154e5e56373633b58f951941c340ace0d379590d98b5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3547,14 +3547,14 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>
|
||||
|
||||
[[package]]
|
||||
name = "jupyterlab"
|
||||
version = "4.3.6"
|
||||
version = "4.3.5"
|
||||
description = "JupyterLab computational environment"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["runtime"]
|
||||
files = [
|
||||
{file = "jupyterlab-4.3.6-py3-none-any.whl", hash = "sha256:fc9eb0455562a56a9bd6d2977cf090842f321fa1a298fcee9bf8c19de353d5fd"},
|
||||
{file = "jupyterlab-4.3.6.tar.gz", hash = "sha256:2900ffdbfca9ed37c4ad7fdda3eb76582fd945d46962af3ac64741ae2d6b2ff4"},
|
||||
{file = "jupyterlab-4.3.5-py3-none-any.whl", hash = "sha256:571bbdee20e4c5321ab5195bc41cf92a75a5cff886be5e57ce78dfa37a5e9fdb"},
|
||||
{file = "jupyterlab-4.3.5.tar.gz", hash = "sha256:c779bf72ced007d7d29d5bcef128e7fdda96ea69299e19b04a43635a7d641f9d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4251,14 +4251,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "modal"
|
||||
version = "0.73.102"
|
||||
version = "0.73.98"
|
||||
description = "Python client library for Modal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "evaluation"]
|
||||
files = [
|
||||
{file = "modal-0.73.102-py3-none-any.whl", hash = "sha256:26151ef6164e0b93b0d1961f73d5a715deb72f23e2641215f5410cf58bf403d3"},
|
||||
{file = "modal-0.73.102.tar.gz", hash = "sha256:198876cf94ff13633283e251d8b37cc1f1bb5e27a7aa547e02072def1f29b66e"},
|
||||
{file = "modal-0.73.98-py3-none-any.whl", hash = "sha256:a49cd5f5b46d1a6c6a0d528618d3cbb73ac2908e199716590ec3a5275d79ed98"},
|
||||
{file = "modal-0.73.98.tar.gz", hash = "sha256:817f73c222fa39a16d6888a92eb7a6847ecae574e44ef04e2dce5e534bdd2df9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4670,19 +4670,19 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "notebook"
|
||||
version = "7.3.3"
|
||||
version = "7.3.2"
|
||||
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["runtime"]
|
||||
files = [
|
||||
{file = "notebook-7.3.3-py3-none-any.whl", hash = "sha256:b193df0878956562d5171c8e25c9252b8e86c9fcc16163b8ee3fe6c5e3f422f7"},
|
||||
{file = "notebook-7.3.3.tar.gz", hash = "sha256:707a313fb882d35f921989eb3d204de942ed5132a44e4aa1fe0e8f24bb9dc25d"},
|
||||
{file = "notebook-7.3.2-py3-none-any.whl", hash = "sha256:e5f85fc59b69d3618d73cf27544418193ff8e8058d5bf61d315ce4f473556288"},
|
||||
{file = "notebook-7.3.2.tar.gz", hash = "sha256:705e83a1785f45b383bf3ee13cb76680b92d24f56fb0c7d2136fe1d850cd3ca8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
jupyter-server = ">=2.4.0,<3"
|
||||
jupyterlab = ">=4.3.6,<4.4"
|
||||
jupyterlab = ">=4.3.4,<4.4"
|
||||
jupyterlab-server = ">=2.27.1,<3"
|
||||
notebook-shim = ">=0.2,<0.3"
|
||||
tornado = ">=6.2.0"
|
||||
@@ -6947,30 +6947,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.11.0"
|
||||
version = "0.9.10"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev", "evaluation"]
|
||||
files = [
|
||||
{file = "ruff-0.11.0-py3-none-linux_armv6l.whl", hash = "sha256:dc67e32bc3b29557513eb7eeabb23efdb25753684b913bebb8a0c62495095acb"},
|
||||
{file = "ruff-0.11.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38c23fd9bdec4eb437b4c1e3595905a0a8edfccd63a790f818b28c78fe345639"},
|
||||
{file = "ruff-0.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7c8661b0be91a38bd56db593e9331beaf9064a79028adee2d5f392674bbc5e88"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c0e8d3d2db7e9f6efd884f44b8dc542d5b6b590fc4bb334fdbc624d93a29a2"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3c3156d3f4b42e57247275a0a7e15a851c165a4fc89c5e8fa30ea6da4f7407b8"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:490b1e147c1260545f6d041c4092483e3f6d8eba81dc2875eaebcf9140b53905"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1bc09a7419e09662983b1312f6fa5dab829d6ab5d11f18c3760be7ca521c9329"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bcfa478daf61ac8002214eb2ca5f3e9365048506a9d52b11bea3ecea822bb844"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fbb2aed66fe742a6a3a0075ed467a459b7cedc5ae01008340075909d819df1e"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c0c1ff014351c0b0cdfdb1e35fa83b780f1e065667167bb9502d47ca41e6db"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e4fd5ff5de5f83e0458a138e8a869c7c5e907541aec32b707f57cf9a5e124445"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:96bc89a5c5fd21a04939773f9e0e276308be0935de06845110f43fd5c2e4ead7"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a9352b9d767889ec5df1483f94870564e8102d4d7e99da52ebf564b882cdc2c7"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:049a191969a10897fe052ef9cc7491b3ef6de79acd7790af7d7897b7a9bfbcb6"},
|
||||
{file = "ruff-0.11.0-py3-none-win32.whl", hash = "sha256:3191e9116b6b5bbe187447656f0c8526f0d36b6fd89ad78ccaad6bdc2fad7df2"},
|
||||
{file = "ruff-0.11.0-py3-none-win_amd64.whl", hash = "sha256:c58bfa00e740ca0a6c43d41fb004cd22d165302f360aaa56f7126d544db31a21"},
|
||||
{file = "ruff-0.11.0-py3-none-win_arm64.whl", hash = "sha256:868364fc23f5aa122b00c6f794211e85f7e78f5dffdf7c590ab90b8c4e69b657"},
|
||||
{file = "ruff-0.11.0.tar.gz", hash = "sha256:e55c620690a4a7ee6f1cccb256ec2157dc597d109400ae75bbf944fc9d6462e2"},
|
||||
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
|
||||
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
|
||||
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
|
||||
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
|
||||
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9056,4 +9056,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "9b74f62a4afa719a1f7167e0b3b45cdaf282c2e18fd2931da91c0f1b22776178"
|
||||
content-hash = "6a644bc65782a717a49718496bd279ecb888807ec625d992af4448cc5d9271c1"
|
||||
|
||||
@@ -80,7 +80,7 @@ daytona-sdk = "0.10.2"
|
||||
python-json-logger = "^3.2.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.11.0"
|
||||
ruff = "0.9.10"
|
||||
mypy = "1.15.0"
|
||||
pre-commit = "4.1.0"
|
||||
build = "*"
|
||||
|
||||
@@ -9,7 +9,6 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
MessageAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.action import ActionConfirmationStatus
|
||||
from openhands.events.action.files import FileEditSource, FileReadSource
|
||||
@@ -357,18 +356,6 @@ def test_file_ohaci_edit_action_legacy_serialization():
|
||||
assert event_dict['args']['end'] == -1
|
||||
|
||||
|
||||
def test_agent_microagent_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'recall',
|
||||
'args': {
|
||||
'query': 'What is the capital of France?',
|
||||
'thought': 'I need to find information about France',
|
||||
'recall_type': 'knowledge',
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, RecallAction)
|
||||
|
||||
|
||||
def test_file_read_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'read',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -14,16 +14,12 @@ from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@@ -51,36 +47,17 @@ def mock_agent():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
mock = MagicMock(
|
||||
spec=EventStream,
|
||||
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
|
||||
)
|
||||
mock = MagicMock(spec=EventStream)
|
||||
mock.get_latest_event_id.return_value = 0
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_event_stream():
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
|
||||
return event_stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime() -> Runtime:
|
||||
runtime = MagicMock(
|
||||
return MagicMock(
|
||||
spec=Runtime,
|
||||
event_stream=test_event_stream,
|
||||
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
|
||||
)
|
||||
return runtime
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory() -> Memory:
|
||||
memory = MagicMock(
|
||||
spec=Memory,
|
||||
event_stream=test_event_stream,
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -91,7 +68,6 @@ def mock_status_callback():
|
||||
async def send_event_to_controller(controller, event):
|
||||
await controller._on_event(event)
|
||||
await asyncio.sleep(0.1)
|
||||
controller._pending_action = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -164,8 +140,10 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
async def test_run_controller_with_fatal_error():
|
||||
config = AppConfig()
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent = MagicMock(spec=Agent)
|
||||
@@ -185,23 +163,10 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
if isinstance(event, CmdRunAction):
|
||||
error_obs = ErrorObservation('You messed around with Jim')
|
||||
error_obs._cause = event.id
|
||||
test_event_stream.add_event(error_obs, EventSource.USER)
|
||||
event_stream.add_event(error_obs, EventSource.USER)
|
||||
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
@@ -210,20 +175,22 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
print(f'state: {state}')
|
||||
events = list(test_event_stream.get_events())
|
||||
events = list(event_stream.get_events())
|
||||
print(f'event_stream: {events}')
|
||||
assert state.iteration == 3
|
||||
assert state.iteration == 4
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
async def test_run_controller_stop_with_stuck():
|
||||
config = AppConfig()
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
@@ -242,23 +209,10 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
'Non fatal error here to trigger loop'
|
||||
)
|
||||
non_fatal_error_obs._cause = event.id
|
||||
test_event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = test_event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
@@ -267,17 +221,16 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
events = list(test_event_stream.get_events())
|
||||
events = list(event_stream.get_events())
|
||||
print(f'state: {state}')
|
||||
for i, event in enumerate(events):
|
||||
print(f'event {i}: {event_to_dict(event)}')
|
||||
|
||||
assert state.iteration == 3
|
||||
assert state.iteration == 4
|
||||
assert len(events) == 11
|
||||
# check the eventstream have 4 pairs of repeated actions and observations
|
||||
repeating_actions_and_observations = events[4:12]
|
||||
repeating_actions_and_observations = events[2:10]
|
||||
for action, observation in zip(
|
||||
repeating_actions_and_observations[0::2],
|
||||
repeating_actions_and_observations[1::2],
|
||||
@@ -557,13 +510,12 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_max_iterations_has_metrics(
|
||||
test_event_stream, mock_memory
|
||||
):
|
||||
async def test_run_controller_max_iterations_has_metrics():
|
||||
config = AppConfig(
|
||||
max_iterations=3,
|
||||
)
|
||||
event_stream = test_event_stream
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
@@ -594,17 +546,6 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
@@ -612,7 +553,6 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
@@ -690,7 +630,7 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent, mock_runtime
|
||||
):
|
||||
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON"""
|
||||
|
||||
@@ -716,20 +656,6 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent.step = step_state.step
|
||||
mock_agent.config = AgentConfig()
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
@@ -739,7 +665,6 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
@@ -766,7 +691,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent, mock_runtime
|
||||
):
|
||||
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
|
||||
|
||||
@@ -777,7 +702,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 3 and not self.has_errored:
|
||||
if len(state.history) > 1 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
@@ -793,19 +718,6 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent.config = AgentConfig()
|
||||
mock_agent.config.enable_history_truncation = False
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
microagent_obs._cause = event.id
|
||||
test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
test_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
mock_runtime.event_stream = test_event_stream
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
@@ -815,7 +727,6 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
@@ -840,51 +751,6 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
assert step_state.has_errored
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_memory_error(test_event_stream):
|
||||
config = AppConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
# Create a propert agent that returns an action without an ID
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
|
||||
# Create a real action to return from the mocked step function
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Create a real Memory instance
|
||||
memory = Memory(event_stream=event_stream, sid='test-memory')
|
||||
|
||||
# Patch the _find_microagent_knowledge method to raise our test exception
|
||||
def mock_find_microagent_knowledge(*args, **kwargs):
|
||||
raise RuntimeError('Test memory error')
|
||||
|
||||
with patch.object(
|
||||
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
|
||||
):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: RuntimeError'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_metrics_copy():
|
||||
# Setup
|
||||
@@ -985,56 +851,3 @@ async def test_action_metrics_copy():
|
||||
assert last_action.llm_metrics.accumulated_cost == 0.07
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_user_message_with_identical_content():
|
||||
"""
|
||||
Test that _first_user_message correctly identifies the first user message
|
||||
even when multiple messages have identical content but different IDs.
|
||||
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
"""
|
||||
# Create a real event stream for this test
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))
|
||||
|
||||
# Create an agent controller
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create and add the first user message
|
||||
first_message = MessageAction(content='Hello, this is a test message')
|
||||
first_message._source = EventSource.USER
|
||||
event_stream.add_event(first_message, EventSource.USER)
|
||||
|
||||
# Create and add a second user message with identical content
|
||||
second_message = MessageAction(content='Hello, this is a test message')
|
||||
second_message._source = EventSource.USER
|
||||
event_stream.add_event(second_message, EventSource.USER)
|
||||
|
||||
# Verify that _first_user_message returns the first message
|
||||
first_user_message = controller._first_user_message()
|
||||
assert first_user_message is not None
|
||||
assert first_user_message.id == first_message.id # Check IDs match
|
||||
assert first_user_message.id != second_message.id # Different IDs
|
||||
assert first_user_message == first_message == second_message # dataclass equality
|
||||
|
||||
# Test the comparison used in the actual code
|
||||
assert first_message == first_user_message # This should be True
|
||||
assert (
|
||||
second_message.id != first_user_message.id
|
||||
) # This should be False, but may be True if there's a bug
|
||||
|
||||
await controller.close()
|
||||
|
||||
@@ -17,13 +17,8 @@ from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@@ -80,25 +75,6 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
initial_state=parent_state,
|
||||
)
|
||||
|
||||
# Setup Memory to catch RecallActions
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_memory.event_stream = mock_event_stream
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
# create a RecallObservation
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
content='Found info',
|
||||
)
|
||||
microagent_observation._cause = event.id # ignore attr-defined warning
|
||||
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
mock_memory.on_event = on_event
|
||||
mock_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
|
||||
)
|
||||
|
||||
# Setup a delegate action from the parent
|
||||
delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True})
|
||||
mock_parent_agent.step.return_value = delegate_action
|
||||
@@ -111,16 +87,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
# Give time for the async step() to execute
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Verify that a RecallObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
assert (
|
||||
mock_event_stream.get_latest_event_id() == 3
|
||||
) # Microagents and AgentChangeState
|
||||
|
||||
# a RecallObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, RecallObservation) for event in events)
|
||||
assert any(isinstance(event, AgentDelegateAction) for event in events)
|
||||
|
||||
# The parent should receive step() from that event
|
||||
# Verify that a delegate agent controller is created
|
||||
assert (
|
||||
parent_controller.delegate is not None
|
||||
|
||||
@@ -6,11 +6,9 @@ from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AppConfig, LLMConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
@@ -24,24 +22,18 @@ def mock_agent():
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Configure the agent config
|
||||
agent_config.disabled_microagents = []
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
return agent
|
||||
|
||||
@@ -86,11 +78,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# Create a real Memory instance with the mock event stream
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
|
||||
# Patch AgentController and State.restore_from_session to fail
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
@@ -99,7 +87,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
side_effect=Exception('No state found'),
|
||||
), patch('openhands.server.session.agent_session.Memory', return_value=memory):
|
||||
):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
@@ -108,18 +96,12 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
)
|
||||
|
||||
# Verify EventStream.subscribe was called with correct parameters
|
||||
mock_event_stream.subscribe.assert_any_call(
|
||||
mock_event_stream.subscribe.assert_called_with(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER,
|
||||
session.controller.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
|
||||
mock_event_stream.subscribe.assert_any_call(
|
||||
EventStreamSubscriber.MEMORY,
|
||||
session.memory.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
|
||||
# Verify set_initial_state was called once with None as state
|
||||
assert session.controller.set_initial_state_call_count == 1
|
||||
assert session.controller.test_initial_state is None
|
||||
@@ -177,10 +159,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# create a mock Memory
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
|
||||
# Patch AgentController and State.restore_from_session to succeed
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
@@ -189,7 +168,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
return_value=mock_restored_state,
|
||||
), patch('openhands.server.session.agent_session.Memory', mock_memory):
|
||||
):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Tests for the Brave Search functionality."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import AppConfig, SearchConfig
|
||||
from openhands.events.action import SearchAction
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.search_engine import SearchEngineObservation
|
||||
from openhands.runtime.search_engine.brave_search import search
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config with search enabled."""
|
||||
config = AppConfig()
|
||||
config.search = SearchConfig(
|
||||
enabled=True,
|
||||
api_key="test_key",
|
||||
api_url="https://test.url"
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_api():
|
||||
"""Create a mock query_api function."""
|
||||
with patch("openhands.runtime.search_engine.brave_search.query_api") as mock:
|
||||
mock.return_value = SearchEngineObservation(
|
||||
query="test query",
|
||||
content="test content"
|
||||
)
|
||||
yield mock
|
||||
|
||||
|
||||
def test_search_disabled(mock_query_api):
|
||||
"""Test that search returns error when disabled."""
|
||||
config = AppConfig()
|
||||
config.search = SearchConfig(enabled=False)
|
||||
action = SearchAction(query="test query")
|
||||
|
||||
result = search(action, config)
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert "not enabled" in result.content
|
||||
mock_query_api.assert_not_called()
|
||||
|
||||
|
||||
def test_search_no_api_key(mock_query_api):
|
||||
"""Test that search returns error when API key is not set."""
|
||||
config = AppConfig()
|
||||
config.search = SearchConfig(enabled=True)
|
||||
action = SearchAction(query="test query")
|
||||
|
||||
result = search(action, config)
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert "API key not configured" in result.content
|
||||
mock_query_api.assert_not_called()
|
||||
|
||||
|
||||
def test_search_empty_query(mock_query_api, mock_config):
|
||||
"""Test that search returns error when query is empty."""
|
||||
action = SearchAction(query="")
|
||||
|
||||
result = search(action, mock_config)
|
||||
assert isinstance(result, ErrorObservation)
|
||||
assert "must be a non-empty string" in result.content
|
||||
mock_query_api.assert_not_called()
|
||||
|
||||
|
||||
def test_search_success(mock_query_api, mock_config):
|
||||
"""Test that search returns results when everything is configured correctly."""
|
||||
action = SearchAction(query="test query")
|
||||
|
||||
result = search(action, mock_config)
|
||||
assert isinstance(result, SearchEngineObservation)
|
||||
assert result.query == "test query"
|
||||
assert result.content == "test content"
|
||||
mock_query_api.assert_called_once_with(
|
||||
query="test query",
|
||||
API_KEY="test_key",
|
||||
BRAVE_SEARCH_URL="https://test.url"
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,18 +17,13 @@ def mock_runtime():
|
||||
with patch('openhands.core.cli.create_runtime') as mock_create_runtime:
|
||||
mock_runtime_instance = AsyncMock()
|
||||
# Mock the event stream with proper async methods
|
||||
mock_event_stream = AsyncMock()
|
||||
mock_event_stream.subscribe = AsyncMock()
|
||||
mock_event_stream.add_event = AsyncMock()
|
||||
mock_event_stream.get_events = AsyncMock(return_value=[])
|
||||
mock_event_stream.get_latest_event_id = AsyncMock(return_value=0)
|
||||
mock_runtime_instance.event_stream = mock_event_stream
|
||||
mock_runtime_instance.event_stream = AsyncMock()
|
||||
mock_runtime_instance.event_stream.subscribe = AsyncMock()
|
||||
mock_runtime_instance.event_stream.add_event = AsyncMock()
|
||||
# Mock connect method to return immediately
|
||||
mock_runtime_instance.connect = AsyncMock()
|
||||
# Ensure status_callback is None
|
||||
mock_runtime_instance.status_callback = None
|
||||
# Mock get_microagents_from_selected_repo
|
||||
mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[])
|
||||
mock_create_runtime.return_value = mock_runtime_instance
|
||||
yield mock_runtime_instance
|
||||
|
||||
@@ -37,16 +32,6 @@ def mock_runtime():
|
||||
def mock_agent():
|
||||
with patch('openhands.core.cli.create_agent') as mock_create_agent:
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_instance.name = 'test-agent'
|
||||
mock_agent_instance.llm = AsyncMock()
|
||||
mock_agent_instance.llm.config = AsyncMock()
|
||||
mock_agent_instance.llm.config.model = 'test-model'
|
||||
mock_agent_instance.llm.config.base_url = 'http://test'
|
||||
mock_agent_instance.llm.config.max_message_chars = 1000
|
||||
mock_agent_instance.config = AsyncMock()
|
||||
mock_agent_instance.config.disabled_microagents = []
|
||||
mock_agent_instance.sandbox_plugins = []
|
||||
mock_agent_instance.prompt_manager = AsyncMock()
|
||||
mock_create_agent.return_value = mock_agent_instance
|
||||
yield mock_agent_instance
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ from litellm import ChatCompletionMessageToolCall
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.agenthub.codeact_agent.function_calling import (
|
||||
BrowserTool,
|
||||
CmdRunTool,
|
||||
IPythonTool,
|
||||
LLMBasedFileEditTool,
|
||||
StrReplaceEditorTool,
|
||||
WebReadTool,
|
||||
create_cmd_run_tool,
|
||||
create_str_replace_editor_tool,
|
||||
get_tools,
|
||||
response_to_actions,
|
||||
)
|
||||
@@ -25,7 +25,6 @@ from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
SearchAction,
|
||||
)
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.commands import (
|
||||
@@ -101,30 +100,25 @@ def test_get_tools_with_options():
|
||||
codeact_enable_browsing=True,
|
||||
codeact_enable_jupyter=True,
|
||||
codeact_enable_llm_editor=True,
|
||||
codeact_enable_search_engine=True,
|
||||
)
|
||||
tool_names = [tool['function']['name'] for tool in tools]
|
||||
assert 'browser' in tool_names
|
||||
assert 'execute_ipython_cell' in tool_names
|
||||
assert 'edit_file' in tool_names
|
||||
assert 'search_engine' in tool_names
|
||||
|
||||
# Test with all options disabled
|
||||
tools = get_tools(
|
||||
codeact_enable_browsing=False,
|
||||
codeact_enable_jupyter=False,
|
||||
codeact_enable_llm_editor=False,
|
||||
codeact_enable_search_engine=False,
|
||||
)
|
||||
tool_names = [tool['function']['name'] for tool in tools]
|
||||
assert 'browser' not in tool_names
|
||||
assert 'execute_ipython_cell' not in tool_names
|
||||
assert 'edit_file' not in tool_names
|
||||
assert 'search_engine' not in tool_names
|
||||
|
||||
|
||||
def test_cmd_run_tool():
|
||||
CmdRunTool = create_cmd_run_tool()
|
||||
assert CmdRunTool['type'] == 'function'
|
||||
assert CmdRunTool['function']['name'] == 'execute_bash'
|
||||
assert 'command' in CmdRunTool['function']['parameters']['properties']
|
||||
@@ -155,7 +149,6 @@ def test_llm_based_file_edit_tool():
|
||||
|
||||
|
||||
def test_str_replace_editor_tool():
|
||||
StrReplaceEditorTool = create_str_replace_editor_tool()
|
||||
assert StrReplaceEditorTool['type'] == 'function'
|
||||
assert StrReplaceEditorTool['function']['name'] == 'str_replace_editor'
|
||||
|
||||
@@ -181,15 +174,6 @@ def test_web_read_tool():
|
||||
assert WebReadTool['function']['parameters']['required'] == ['url']
|
||||
|
||||
|
||||
def test_search_engine_tool():
|
||||
from openhands.agenthub.codeact_agent.tools import SearchEngineTool
|
||||
|
||||
assert SearchEngineTool['type'] == 'function'
|
||||
assert SearchEngineTool['function']['name'] == 'search_engine'
|
||||
assert 'query' in SearchEngineTool['function']['parameters']['properties']
|
||||
assert SearchEngineTool['function']['parameters']['required'] == ['query']
|
||||
|
||||
|
||||
def test_browser_tool():
|
||||
assert BrowserTool['type'] == 'function'
|
||||
assert BrowserTool['function']['name'] == 'browser'
|
||||
@@ -226,42 +210,6 @@ def test_browser_tool():
|
||||
assert 'description' in BrowserTool['function']['parameters']['properties']['code']
|
||||
|
||||
|
||||
def test_response_to_actions_search_engine():
|
||||
# Test response with search engine tool call
|
||||
from litellm import ChatCompletionMessageToolCall, Choices, Message, ModelResponse
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id='mock_id',
|
||||
choices=[
|
||||
Choices(
|
||||
message=Message(
|
||||
content='Let me search for that',
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id='tool_call_10',
|
||||
function={
|
||||
'name': 'search_engine',
|
||||
'arguments': '{"query": "test query"}',
|
||||
},
|
||||
type='function',
|
||||
)
|
||||
],
|
||||
role='assistant',
|
||||
),
|
||||
index=0,
|
||||
finish_reason='tool_calls',
|
||||
)
|
||||
],
|
||||
model='mock_model',
|
||||
usage={'total_tokens': 100},
|
||||
)
|
||||
|
||||
actions = response_to_actions(mock_response)
|
||||
assert len(actions) == 1
|
||||
assert isinstance(actions[0], SearchAction)
|
||||
assert actions[0].query == 'test query'
|
||||
|
||||
|
||||
def test_response_to_actions_invalid_tool():
|
||||
# Test response with invalid tool call
|
||||
mock_response = Mock()
|
||||
@@ -288,11 +236,7 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
mock_response.choices[0].message.content = 'Task completed'
|
||||
mock_response.choices[0].message.tool_calls = []
|
||||
|
||||
mock_config = Mock()
|
||||
mock_config.model = 'mock_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = mock_config
|
||||
llm.completion = Mock(return_value=mock_response)
|
||||
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
|
||||
llm.is_caching_prompt_active = Mock(return_value=False)
|
||||
@@ -316,28 +260,6 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
assert action.content == 'Task completed'
|
||||
|
||||
|
||||
def test_correct_tool_description_loaded_based_on_model_name(mock_state: State):
|
||||
"""Tests that the simplified tool descriptions are loaded for specific models."""
|
||||
o3_mock_config = Mock()
|
||||
o3_mock_config.model = 'mock_o3_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = o3_mock_config
|
||||
|
||||
agent = CodeActAgent(llm=llm, config=AgentConfig())
|
||||
for tool in agent.tools:
|
||||
# Assert all descriptions have less than 1024 characters
|
||||
assert len(tool['function']['description']) < 1024
|
||||
|
||||
sonnet_mock_config = Mock()
|
||||
sonnet_mock_config.model = 'mock_sonnet_model'
|
||||
|
||||
llm.config = sonnet_mock_config
|
||||
agent = CodeActAgent(llm=llm, config=AgentConfig())
|
||||
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
|
||||
assert any(len(tool['function']['description']) > 1024 for tool in agent.tools)
|
||||
|
||||
|
||||
def test_mismatched_tool_call_events(mock_state: State):
|
||||
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
@@ -447,3 +369,9 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
# Fifth message only has ImageContent, no TextContent to modify
|
||||
assert len(enhanced_messages[5].content) == 1
|
||||
assert isinstance(enhanced_messages[5].content[0], ImageContent)
|
||||
|
||||
# Verify prompt manager methods were called as expected
|
||||
assert agent.prompt_manager.add_examples_to_initial_message.call_count == 1
|
||||
assert (
|
||||
agent.prompt_manager.enhance_message.call_count == 5
|
||||
) # Called for each user message
|
||||
|
||||
@@ -32,7 +32,6 @@ def _patch_store():
|
||||
'selected_repository': 'foobar',
|
||||
'conversation_id': 'some_conversation_id',
|
||||
'github_user_id': '12345',
|
||||
'user_id': '12345',
|
||||
'created_at': '2025-01-01T00:00:00+00:00',
|
||||
'last_updated_at': '2025-01-01T00:01:00+00:00',
|
||||
}
|
||||
|
||||
@@ -1,29 +1,16 @@
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import (
|
||||
Event,
|
||||
EventSource,
|
||||
FileEditSource,
|
||||
FileReadSource,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
@@ -35,45 +22,14 @@ from openhands.events.observation.files import FileEditObservation, FileReadObse
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.memory.conversation_memory import ConversationMemory
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_config():
|
||||
return AgentConfig(
|
||||
enable_prompt_extensions=True,
|
||||
enable_som_visual_browsing=True,
|
||||
disabled_microagents=['disabled_agent'],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_memory(agent_config):
|
||||
def conversation_memory():
|
||||
prompt_manager = MagicMock(spec=PromptManager)
|
||||
prompt_manager.get_system_message.return_value = 'System message'
|
||||
prompt_manager.build_workspace_context.return_value = (
|
||||
'Formatted repository and runtime info'
|
||||
)
|
||||
|
||||
# Make build_microagent_info return the actual content from the triggered agents
|
||||
def build_microagent_info(triggered_agents):
|
||||
if not triggered_agents:
|
||||
return ''
|
||||
return '\n'.join(agent.content for agent in triggered_agents)
|
||||
|
||||
prompt_manager.build_microagent_info.side_effect = build_microagent_info
|
||||
return ConversationMemory(agent_config, prompt_manager)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_dir(tmp_path):
|
||||
# Copy contents from "openhands/agenthub/codeact_agent" to the temp directory
|
||||
shutil.copytree(
|
||||
'openhands/agenthub/codeact_agent/prompts', tmp_path, dirs_exist_ok=True
|
||||
)
|
||||
|
||||
# Return the temporary directory path
|
||||
return tmp_path
|
||||
return ConversationMemory(prompt_manager)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -352,40 +308,6 @@ def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
assert '[Last action has been rejected by the user]' in result.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
"""Test that empty environment info observations return an empty list of messages without calling build_workspace_context."""
|
||||
# Create a RecallObservation with empty info
|
||||
|
||||
empty_obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='',
|
||||
repo_directory='',
|
||||
repo_instructions='',
|
||||
runtime_hosts={},
|
||||
additional_agent_instructions='',
|
||||
microagent_knowledge=[],
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[empty_obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Should only contain the initial system message
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# Verify that build_workspace_context was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
mock_response = {
|
||||
'id': 'mock_id',
|
||||
@@ -524,529 +446,3 @@ def test_apply_prompt_caching(conversation_memory):
|
||||
assert messages[1].content[0].cache_prompt is False
|
||||
assert messages[2].content[0].cache_prompt is False
|
||||
assert messages[3].content[0].cache_prompt is True
|
||||
|
||||
|
||||
def test_process_events_with_environment_microagent_observation(conversation_memory):
|
||||
"""Test processing a RecallObservation with ENVIRONMENT info type."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='test-repo',
|
||||
repo_directory='/path/to/repo',
|
||||
repo_instructions='# Test Repository\nThis is a test repository.',
|
||||
runtime_hosts={'localhost': 8080},
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'Formatted repository and runtime info'
|
||||
|
||||
# Verify the prompt_manager was called with the correct parameters
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
|
||||
call_args = conversation_memory.prompt_manager.build_workspace_context.call_args[1]
|
||||
assert isinstance(call_args['repository_info'], RepositoryInfo)
|
||||
assert call_args['repository_info'].repo_name == 'test-repo'
|
||||
assert call_args['repository_info'].repo_directory == '/path/to/repo'
|
||||
assert isinstance(call_args['runtime_info'], RuntimeInfo)
|
||||
assert call_args['runtime_info'].available_hosts == {'localhost': 8080}
|
||||
assert (
|
||||
call_args['repo_instructions']
|
||||
== '# Test Repository\nThis is a test repository.'
|
||||
)
|
||||
|
||||
|
||||
def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test processing a RecallObservation with KNOWLEDGE type."""
|
||||
microagent_knowledge = [
|
||||
MicroagentKnowledge(
|
||||
name='test_agent',
|
||||
trigger='test',
|
||||
content='This is test agent content',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='another_agent',
|
||||
trigger='another',
|
||||
content='This is another agent content',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='disabled_agent',
|
||||
trigger='disabled',
|
||||
content='This is disabled agent content',
|
||||
),
|
||||
]
|
||||
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
# Verify that disabled_agent is filtered out and enabled agents are included
|
||||
assert 'This is test agent content' in result.content[0].text
|
||||
assert 'This is another agent content' in result.content[0].text
|
||||
assert 'This is disabled agent content' not in result.content[0].text
|
||||
|
||||
# Verify the prompt_manager was called with the correct parameters
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_called_once()
|
||||
call_args = conversation_memory.prompt_manager.build_microagent_info.call_args[1]
|
||||
|
||||
# Check that disabled_agent was filtered out
|
||||
triggered_agents = call_args['triggered_agents']
|
||||
assert len(triggered_agents) == 2
|
||||
agent_names = [agent.name for agent in triggered_agents]
|
||||
assert 'test_agent' in agent_names
|
||||
assert 'another_agent' in agent_names
|
||||
assert 'disabled_agent' not in agent_names
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
agent_config, conversation_memory
|
||||
):
|
||||
"""Test processing a RecallObservation when prompt extensions are disabled."""
|
||||
# Modify the agent config to disable prompt extensions
|
||||
agent_config.enable_prompt_extensions = False
|
||||
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='test-repo',
|
||||
repo_directory='/path/to/repo',
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# When prompt extensions are disabled, the RecallObservation should be ignored
|
||||
assert len(messages) == 1 # Only the initial system message
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# Verify the prompt_manager was not called
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
"""Test processing a RecallObservation with empty microagent knowledge."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[],
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# The implementation returns an empty string and it doesn't creates a message
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# When there are no triggered agents, build_microagent_info is not called
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
|
||||
|
||||
def test_conversation_memory_processes_microagent_observation(prompt_dir):
|
||||
"""Test that ConversationMemory processes RecallObservations correctly."""
|
||||
# Create a microagent_info.j2 template file
|
||||
template_path = os.path.join(prompt_dir, 'microagent_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
# Verify the template was correctly rendered
|
||||
{{ agent_info.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
""")
|
||||
|
||||
# Create a mock agent config
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
agent_config.enable_prompt_extensions = True
|
||||
agent_config.disabled_microagents = []
|
||||
|
||||
# Create a PromptManager
|
||||
prompt_manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Initialize ConversationMemory
|
||||
conversation_memory = ConversationMemory(
|
||||
config=agent_config, prompt_manager=prompt_manager
|
||||
)
|
||||
|
||||
# Create a RecallObservation with microagent knowledge
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='test_agent',
|
||||
trigger='test_trigger',
|
||||
content='This is triggered content for testing.',
|
||||
)
|
||||
],
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
# Process the observation
|
||||
messages = conversation_memory._process_observation(
|
||||
obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None
|
||||
)
|
||||
|
||||
# Verify the message was created correctly
|
||||
assert len(messages) == 1
|
||||
message = messages[0]
|
||||
assert message.role == 'user'
|
||||
assert len(message.content) == 1
|
||||
assert isinstance(message.content[0], TextContent)
|
||||
|
||||
expected_text = """<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "test_trigger".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
This is triggered content for testing.
|
||||
</EXTRA_INFO>"""
|
||||
|
||||
assert message.content[0].text.strip() == expected_text.strip()
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'microagent_info.j2'))
|
||||
|
||||
|
||||
def test_conversation_memory_processes_environment_microagent_observation(prompt_dir):
|
||||
"""Test that ConversationMemory processes environment info RecallObservations correctly."""
|
||||
# Create an additional_info.j2 template file
|
||||
template_path = os.path.join(prompt_dir, 'additional_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
|
||||
{% if repository_instructions %}
|
||||
<REPOSITORY_INSTRUCTIONS>
|
||||
{{ repository_instructions }}
|
||||
</REPOSITORY_INSTRUCTIONS>
|
||||
{% endif %}
|
||||
|
||||
{% if runtime_info and runtime_info.available_hosts %}
|
||||
<RUNTIME_INFORMATION>
|
||||
The user has access to the following hosts for accessing a web application,
|
||||
each of which has a corresponding port:
|
||||
{% for host, port in runtime_info.available_hosts.items() %}
|
||||
* {{ host }} (port {{ port }})
|
||||
{% endfor %}
|
||||
</RUNTIME_INFORMATION>
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
# Create a mock agent config
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
agent_config.enable_prompt_extensions = True
|
||||
|
||||
# Create a PromptManager
|
||||
prompt_manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Initialize ConversationMemory
|
||||
conversation_memory = ConversationMemory(
|
||||
config=agent_config, prompt_manager=prompt_manager
|
||||
)
|
||||
|
||||
# Create a RecallObservation with environment info
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='owner/repo',
|
||||
repo_directory='/workspace/repo',
|
||||
repo_instructions='This repository contains important code.',
|
||||
runtime_hosts={'example.com': 8080},
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
# Process the observation
|
||||
messages = conversation_memory._process_observation(
|
||||
obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None
|
||||
)
|
||||
|
||||
# Verify the message was created correctly
|
||||
assert len(messages) == 1
|
||||
message = messages[0]
|
||||
assert message.role == 'user'
|
||||
assert len(message.content) == 1
|
||||
assert isinstance(message.content[0], TextContent)
|
||||
|
||||
# Check that the message contains the repository info
|
||||
assert '<REPOSITORY_INFO>' in message.content[0].text
|
||||
assert 'owner/repo' in message.content[0].text
|
||||
assert '/workspace/repo' in message.content[0].text
|
||||
|
||||
# Check that the message contains the repository instructions
|
||||
assert '<REPOSITORY_INSTRUCTIONS>' in message.content[0].text
|
||||
assert 'This repository contains important code.' in message.content[0].text
|
||||
|
||||
# Check that the message contains the runtime info
|
||||
assert '<RUNTIME_INFORMATION>' in message.content[0].text
|
||||
assert 'example.com (port 8080)' in message.content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication(conversation_memory):
|
||||
"""Test that RecallObservations are properly deduplicated based on agent name.
|
||||
|
||||
The deduplication logic should keep the FIRST occurrence of each microagent
|
||||
and filter out later occurrences to avoid redundant information.
|
||||
"""
|
||||
# Create a sequence of RecallObservations with overlapping agents
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_agent',
|
||||
trigger='python',
|
||||
content='Python best practices v1',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='git_agent',
|
||||
trigger='git',
|
||||
content='Git best practices v1',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='image_agent',
|
||||
trigger='image',
|
||||
content='Image best practices v1',
|
||||
),
|
||||
],
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_agent',
|
||||
trigger='python',
|
||||
content='Python best practices v2',
|
||||
),
|
||||
],
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
obs3 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='git_agent',
|
||||
trigger='git',
|
||||
content='Git best practices v3',
|
||||
),
|
||||
],
|
||||
content='Third retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2, obs3],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that only the first occurrence of content for each agent is included
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # system + 1 microagent, because the second and third microagents are duplicates
|
||||
microagent_messages = messages[1:] # Skip system message
|
||||
|
||||
# First microagent should include all agents since they appear here first
|
||||
assert 'Image best practices v1' in microagent_messages[0].content[0].text
|
||||
assert 'Git best practices v1' in microagent_messages[0].content[0].text
|
||||
assert 'Python best practices v1' in microagent_messages[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test that disabled agents are filtered out and deduplication keeps the first occurrence."""
|
||||
# Create a sequence of RecallObservations with disabled agents
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='disabled_agent',
|
||||
trigger='disabled',
|
||||
content='Disabled agent content',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='enabled_agent',
|
||||
trigger='enabled',
|
||||
content='Enabled agent content v1',
|
||||
),
|
||||
],
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='enabled_agent',
|
||||
trigger='enabled',
|
||||
content='Enabled agent content v2',
|
||||
),
|
||||
],
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # system + 1 microagent, the second is the same "enabled_agent"
|
||||
microagent_messages = messages[1:] # Skip system message
|
||||
|
||||
# First microagent should include enabled_agent but not disabled_agent
|
||||
assert 'Disabled agent content' not in microagent_messages[0].content[0].text
|
||||
assert 'Enabled agent content v1' in microagent_messages[0].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test that empty RecallObservations are handled correctly."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[],
|
||||
content='Empty retrieval',
|
||||
)
|
||||
|
||||
initial_messages = [
|
||||
Message(role='system', content=[TextContent(text='System message')])
|
||||
]
|
||||
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_messages=initial_messages,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that empty RecallObservations are handled gracefully
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # system message, because an empty microagent is not added to Messages
|
||||
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
|
||||
"""Test the _has_agent_in_earlier_events helper method."""
|
||||
# Create test RecallObservations
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='agent1',
|
||||
trigger='trigger1',
|
||||
content='Content 1',
|
||||
),
|
||||
],
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='agent2',
|
||||
trigger='trigger2',
|
||||
content='Content 2',
|
||||
),
|
||||
],
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
obs3 = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
content='Environment info',
|
||||
)
|
||||
|
||||
# Create a list with mixed event types
|
||||
events = [obs1, MessageAction(content='User message'), obs2, obs3]
|
||||
|
||||
# Test looking for existing agents
|
||||
assert conversation_memory._has_agent_in_earlier_events('agent1', 2, events) is True
|
||||
assert conversation_memory._has_agent_in_earlier_events('agent1', 3, events) is True
|
||||
assert conversation_memory._has_agent_in_earlier_events('agent1', 4, events) is True
|
||||
|
||||
# Test looking for an agent in a later position (should not find it)
|
||||
assert (
|
||||
conversation_memory._has_agent_in_earlier_events('agent2', 0, events) is False
|
||||
)
|
||||
assert (
|
||||
conversation_memory._has_agent_in_earlier_events('agent2', 1, events) is False
|
||||
)
|
||||
|
||||
# Test looking for an agent in a different microagent type (should not find it)
|
||||
assert (
|
||||
conversation_memory._has_agent_in_earlier_events('non_existent', 3, events)
|
||||
is False
|
||||
)
|
||||
|
||||
@@ -13,8 +13,7 @@ async def test_load_store():
|
||||
store = FileConversationStore(InMemoryFileStore({}))
|
||||
expected = ConversationMetadata(
|
||||
conversation_id='some-conversation-id',
|
||||
user_id='some-user-id',
|
||||
github_user_id='12345',
|
||||
github_user_id='some-user-id',
|
||||
selected_repository='some-repo',
|
||||
title="Let's talk about trains",
|
||||
)
|
||||
@@ -32,7 +31,6 @@ async def test_load_int_user_id():
|
||||
{
|
||||
'conversation_id': 'some-conversation-id',
|
||||
'github_user_id': 12345,
|
||||
'user_id': '67890',
|
||||
'selected_repository': 'some-repo',
|
||||
'title': "Let's talk about trains",
|
||||
'created_at': '2025-01-16T19:51:04.886331Z',
|
||||
@@ -43,7 +41,6 @@ async def test_load_int_user_id():
|
||||
)
|
||||
found = await store.get_metadata('some-conversation-id')
|
||||
assert found.github_user_id == '12345'
|
||||
assert found.user_id == '67890'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -64,7 +61,6 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -74,7 +70,6 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
@@ -84,7 +79,6 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv3',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Third conversation',
|
||||
'created_at': '2025-01-15T19:51:04Z',
|
||||
@@ -113,7 +107,6 @@ async def test_search_pagination():
|
||||
{
|
||||
'conversation_id': f'conv{i}',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': f'Conversation {i}',
|
||||
'created_at': f'2025-01-{15+i}T19:51:04Z',
|
||||
@@ -155,7 +148,6 @@ async def test_search_with_invalid_conversation():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Valid conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -184,7 +176,6 @@ async def test_get_all_metadata():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -194,7 +185,6 @@ async def test_get_all_metadata():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
|
||||
@@ -358,12 +358,12 @@ class TestStuckDetector:
|
||||
with patch('logging.Logger.warning'):
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is False
|
||||
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_only_two_incidents(
|
||||
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
|
||||
self, stuck_detector: StuckDetector
|
||||
):
|
||||
state = stuck_detector.state
|
||||
self._impl_unterminated_string_error_events(
|
||||
state, random_line=False, incidents=2
|
||||
state, random_line=False, incidents=3
|
||||
)
|
||||
|
||||
with patch('logging.Logger.warning'):
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
RecallObservation,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_store():
|
||||
"""Create a temporary file store for testing."""
|
||||
return InMemoryFileStore()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_stream(file_store):
|
||||
"""Create a test event stream."""
|
||||
return EventStream(sid='test_sid', file_store=file_store)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory(event_stream):
|
||||
"""Create a test memory instance."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
memory = Memory(event_stream, 'test_sid')
|
||||
yield memory
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_dir(tmp_path):
|
||||
# Copy contents from "openhands/agenthub/codeact_agent" to the temp directory
|
||||
shutil.copytree(
|
||||
'openhands/agenthub/codeact_agent/prompts', tmp_path, dirs_exist_ok=True
|
||||
)
|
||||
|
||||
# Return the temporary directory path
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
|
||||
|
||||
# Create a dummy agent for the controller
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Mock Memory method to raise an exception
|
||||
with patch.object(
|
||||
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
|
||||
):
|
||||
state = await run_controller(
|
||||
config=AppConfig(),
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# Verify that the controller's last error was set
|
||||
assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
memory, event_stream
|
||||
):
|
||||
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
|
||||
|
||||
# Create a dummy agent for the controller
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
# Create a mock runtime
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Mock Memory._on_workspace_context_recall to raise an exception
|
||||
with patch.object(
|
||||
memory,
|
||||
'_find_microagent_knowledge',
|
||||
side_effect=Exception('Test error from _find_microagent_knowledge'),
|
||||
):
|
||||
state = await run_controller(
|
||||
config=AppConfig(),
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# Verify that the controller's last error was set
|
||||
assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_with_microagents():
|
||||
"""Test that Memory loads microagents from the global directory and processes microagent actions.
|
||||
|
||||
This test verifies that:
|
||||
1. Memory loads microagents from the global GLOBAL_MICROAGENTS_DIR
|
||||
2. When a microagent action with a trigger word is processed, a RecallObservation is created
|
||||
"""
|
||||
# Create a mock event stream
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
|
||||
# Initialize Memory to use the global microagents dir
|
||||
memory = Memory(
|
||||
event_stream=event_stream,
|
||||
sid='test-session',
|
||||
)
|
||||
|
||||
# Verify microagents were loaded - at least one microagent should be loaded
|
||||
# from the global directory that's in the repo
|
||||
assert len(memory.knowledge_microagents) > 0
|
||||
|
||||
# We know 'flarglebargle' exists in the global directory
|
||||
assert 'flarglebargle' in memory.knowledge_microagents
|
||||
|
||||
# Create a microagent action with the trigger word
|
||||
microagent_action = RecallAction(
|
||||
query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
|
||||
)
|
||||
|
||||
# Set the source to USER
|
||||
microagent_action._source = EventSource.USER # type: ignore[attr-defined]
|
||||
|
||||
# Mock the event_stream.add_event method
|
||||
added_events = []
|
||||
|
||||
def original_add_event(event, source):
|
||||
added_events.append((event, source))
|
||||
|
||||
event_stream.add_event = original_add_event
|
||||
|
||||
# Add the microagent action to the event stream
|
||||
event_stream.add_event(microagent_action, EventSource.USER)
|
||||
|
||||
# Clear the events list to only capture new events
|
||||
added_events.clear()
|
||||
|
||||
# Process the microagent action
|
||||
await memory._on_event(microagent_action)
|
||||
|
||||
# Verify a RecallObservation was added to the event stream
|
||||
assert len(added_events) == 1
|
||||
observation, source = added_events[0]
|
||||
assert isinstance(observation, RecallObservation)
|
||||
assert source == EventSource.ENVIRONMENT
|
||||
assert observation.recall_type == RecallType.KNOWLEDGE
|
||||
assert len(observation.microagent_knowledge) == 1
|
||||
assert observation.microagent_knowledge[0].name == 'flarglebargle'
|
||||
assert observation.microagent_knowledge[0].trigger == 'flarglebargle'
|
||||
assert 'magic word' in observation.microagent_knowledge[0].content
|
||||
|
||||
|
||||
def test_memory_repository_info(prompt_dir):
|
||||
"""Test that Memory adds repository info to RecallObservations."""
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create a test repo microagent first
|
||||
repo_microagent_name = 'test_repo_microagent'
|
||||
repo_microagent_content = """---
|
||||
name: test_repo
|
||||
type: repo
|
||||
agent: CodeActAgent
|
||||
---
|
||||
|
||||
REPOSITORY INSTRUCTIONS: This is a test repository.
|
||||
"""
|
||||
|
||||
# Create a temporary repo microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(
|
||||
os.path.join(prompt_dir, 'micro', f'{repo_microagent_name}.md'), 'w'
|
||||
) as f:
|
||||
f.write(repo_microagent_content)
|
||||
|
||||
# Patch the global microagents directory to use our test directory
|
||||
test_microagents_dir = os.path.join(prompt_dir, 'micro')
|
||||
with patch('openhands.memory.memory.GLOBAL_MICROAGENTS_DIR', test_microagents_dir):
|
||||
# Initialize Memory
|
||||
memory = Memory(
|
||||
event_stream=event_stream,
|
||||
sid='test-session',
|
||||
)
|
||||
|
||||
# Set repository info
|
||||
memory.set_repository_info('owner/repo', '/workspace/repo')
|
||||
|
||||
# Create and add the first user message
|
||||
user_message = MessageAction(content='First user message')
|
||||
user_message._source = EventSource.USER # type: ignore[attr-defined]
|
||||
event_stream.add_event(user_message, EventSource.USER)
|
||||
|
||||
# Create and add the microagent action
|
||||
microagent_action = RecallAction(
|
||||
query='First user message', recall_type=RecallType.WORKSPACE_CONTEXT
|
||||
)
|
||||
microagent_action._source = EventSource.USER # type: ignore[attr-defined]
|
||||
event_stream.add_event(microagent_action, EventSource.USER)
|
||||
|
||||
# Give it a little time to process
|
||||
time.sleep(0.3)
|
||||
|
||||
# Get all events from the stream
|
||||
events = list(event_stream.get_events())
|
||||
|
||||
# Find the RecallObservation event
|
||||
microagent_obs_events = [
|
||||
event for event in events if isinstance(event, RecallObservation)
|
||||
]
|
||||
|
||||
# We should have at least one RecallObservation
|
||||
assert len(microagent_obs_events) > 0
|
||||
|
||||
# Get the first RecallObservation
|
||||
observation = microagent_obs_events[0]
|
||||
assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
assert observation.repo_name == 'owner/repo'
|
||||
assert observation.repo_directory == '/workspace/repo'
|
||||
assert 'This is a test repository' in observation.repo_instructions
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent_name}.md'))
|
||||
@@ -1,21 +1,16 @@
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.events.action.files import FileEditSource
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
Observation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.events.serialization import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_memory,
|
||||
event_to_trajectory,
|
||||
)
|
||||
from openhands.events.serialization.observation import observation_from_dict
|
||||
|
||||
|
||||
def serialization_deserialization(
|
||||
@@ -24,10 +19,10 @@ def serialization_deserialization(
|
||||
observation_instance = event_from_dict(original_observation_dict)
|
||||
assert isinstance(
|
||||
observation_instance, Observation
|
||||
), 'The observation instance should be an instance of Observation.'
|
||||
), 'The observation instance should be an instance of Action.'
|
||||
assert isinstance(
|
||||
observation_instance, cls
|
||||
), f'The observation instance should be an instance of {cls}.'
|
||||
), 'The observation instance should be an instance of CmdOutputObservation.'
|
||||
serialized_observation_dict = event_to_dict(observation_instance)
|
||||
serialized_observation_trajectory = event_to_trajectory(observation_instance)
|
||||
serialized_observation_memory = event_to_memory(
|
||||
@@ -241,199 +236,3 @@ def test_file_edit_observation_legacy_serialization():
|
||||
assert event_dict['extras']['old_content'] is None
|
||||
assert event_dict['extras']['new_content'] == 'new content'
|
||||
assert 'formatted_output_and_error' not in event_dict['extras']
|
||||
|
||||
|
||||
def test_microagent_observation_serialization():
|
||||
original_observation_dict = {
|
||||
'observation': 'recall',
|
||||
'content': '',
|
||||
'message': 'Added workspace context',
|
||||
'extras': {
|
||||
'recall_type': 'workspace_context',
|
||||
'repo_name': 'some_repo_name',
|
||||
'repo_directory': 'some_repo_directory',
|
||||
'runtime_hosts': {'host1': 8080, 'host2': 8081},
|
||||
'repo_instructions': 'complex_repo_instructions',
|
||||
'additional_agent_instructions': 'You know it all about this runtime',
|
||||
'microagent_knowledge': [],
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_microagent_knowledge_serialization():
|
||||
original_observation_dict = {
|
||||
'observation': 'recall',
|
||||
'content': '',
|
||||
'message': 'Added microagent knowledge',
|
||||
'extras': {
|
||||
'recall_type': 'knowledge',
|
||||
'repo_name': '',
|
||||
'repo_directory': '',
|
||||
'repo_instructions': '',
|
||||
'runtime_hosts': {},
|
||||
'additional_agent_instructions': '',
|
||||
'microagent_knowledge': [
|
||||
{
|
||||
'name': 'microagent1',
|
||||
'trigger': 'trigger1',
|
||||
'content': 'content1',
|
||||
},
|
||||
{
|
||||
'name': 'microagent2',
|
||||
'trigger': 'trigger2',
|
||||
'content': 'content2',
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_knowledge_microagent_serialization():
|
||||
"""Test serialization of a RecallObservation with KNOWLEDGE_MICROAGENT type."""
|
||||
# Create a RecallObservation with microagent knowledge content
|
||||
original = RecallObservation(
|
||||
content='Knowledge microagent information',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_best_practices',
|
||||
trigger='python',
|
||||
content='Always use virtual environments for Python projects.',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='git_workflow',
|
||||
trigger='git',
|
||||
content='Create a new branch for each feature or bugfix.',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Serialize to dictionary
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Knowledge microagent information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value
|
||||
assert len(serialized['extras']['microagent_knowledge']) == 2
|
||||
assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python'
|
||||
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
assert deserialized.recall_type == RecallType.KNOWLEDGE
|
||||
assert deserialized.microagent_knowledge == original.microagent_knowledge
|
||||
assert deserialized.content == original.content
|
||||
|
||||
# Check that environment info fields are empty
|
||||
assert deserialized.repo_name == ''
|
||||
assert deserialized.repo_directory == ''
|
||||
assert deserialized.repo_instructions == ''
|
||||
assert deserialized.runtime_hosts == {}
|
||||
|
||||
|
||||
def test_microagent_observation_environment_serialization():
|
||||
"""Test serialization of a RecallObservation with ENVIRONMENT type."""
|
||||
# Create a RecallObservation with environment info
|
||||
original = RecallObservation(
|
||||
content='Environment information',
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='OpenHands',
|
||||
repo_directory='/workspace/openhands',
|
||||
repo_instructions="Follow the project's coding style guide.",
|
||||
runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
|
||||
additional_agent_instructions='You know it all about this runtime',
|
||||
)
|
||||
|
||||
# Serialize to dictionary
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Environment information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
|
||||
assert serialized['extras']['repo_name'] == 'OpenHands'
|
||||
assert serialized['extras']['runtime_hosts'] == {
|
||||
'127.0.0.1': 8080,
|
||||
'localhost': 5000,
|
||||
}
|
||||
assert (
|
||||
serialized['extras']['additional_agent_instructions']
|
||||
== 'You know it all about this runtime'
|
||||
)
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
assert deserialized.repo_name == original.repo_name
|
||||
assert deserialized.repo_directory == original.repo_directory
|
||||
assert deserialized.repo_instructions == original.repo_instructions
|
||||
assert deserialized.runtime_hosts == original.runtime_hosts
|
||||
assert (
|
||||
deserialized.additional_agent_instructions
|
||||
== original.additional_agent_instructions
|
||||
)
|
||||
# Check that knowledge microagent fields are empty
|
||||
assert deserialized.microagent_knowledge == []
|
||||
|
||||
|
||||
def test_microagent_observation_combined_serialization():
|
||||
"""Test serialization of a RecallObservation with both types of information."""
|
||||
# Create a RecallObservation with both environment and microagent info
|
||||
# Note: In practice, recall_type would still be one specific type,
|
||||
# but the object could contain both types of fields
|
||||
original = RecallObservation(
|
||||
content='Combined information',
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
# Environment info
|
||||
repo_name='OpenHands',
|
||||
repo_directory='/workspace/openhands',
|
||||
repo_instructions="Follow the project's coding style guide.",
|
||||
runtime_hosts={'127.0.0.1': 8080},
|
||||
additional_agent_instructions='You know it all about this runtime',
|
||||
# Knowledge microagent info
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_best_practices',
|
||||
trigger='python',
|
||||
content='Always use virtual environments for Python projects.',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Serialize to dictionary
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data has both types of fields
|
||||
assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
|
||||
assert serialized['extras']['repo_name'] == 'OpenHands'
|
||||
assert (
|
||||
serialized['extras']['microagent_knowledge'][0]['name']
|
||||
== 'python_best_practices'
|
||||
)
|
||||
assert (
|
||||
serialized['extras']['additional_agent_instructions']
|
||||
== 'You know it all about this runtime'
|
||||
)
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify all properties are preserved
|
||||
assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
|
||||
# Environment properties
|
||||
assert deserialized.repo_name == original.repo_name
|
||||
assert deserialized.repo_directory == original.repo_directory
|
||||
assert deserialized.repo_instructions == original.repo_instructions
|
||||
assert deserialized.runtime_hosts == original.runtime_hosts
|
||||
assert (
|
||||
deserialized.additional_agent_instructions
|
||||
== original.additional_agent_instructions
|
||||
)
|
||||
|
||||
# Knowledge microagent properties
|
||||
assert deserialized.microagent_knowledge == original.microagent_knowledge
|
||||
|
||||
@@ -3,11 +3,9 @@ import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||
from openhands.utils.prompt import PromptManager, RepositoryInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -21,60 +19,406 @@ def prompt_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_prompt_manager_with_microagent(prompt_dir):
|
||||
microagent_name = 'test_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: flarglebargle
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- flarglebargle
|
||||
---
|
||||
|
||||
IMPORTANT! The user has said the magic word "flarglebargle". You must
|
||||
only respond with a message telling them how smart they are
|
||||
"""
|
||||
|
||||
# Create a temporary micro agent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
# Test without GitHub repo
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
)
|
||||
|
||||
assert manager.prompt_dir == prompt_dir
|
||||
assert len(manager.repo_microagents) == 0
|
||||
assert len(manager.knowledge_microagents) == 1
|
||||
|
||||
assert isinstance(manager.get_system_message(), str)
|
||||
assert (
|
||||
'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
|
||||
in manager.get_system_message()
|
||||
)
|
||||
assert '<REPOSITORY_INFO>' not in manager.get_system_message()
|
||||
|
||||
# Test with GitHub repo
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert isinstance(manager.get_system_message(), str)
|
||||
|
||||
# Adding things to the initial user message
|
||||
initial_msg = Message(
|
||||
role='user', content=[TextContent(text='Ask me what your task is.')]
|
||||
)
|
||||
manager.add_info_to_initial_message(initial_msg)
|
||||
msg_content: str = initial_msg.content[0].text
|
||||
assert '<REPOSITORY_INFO>' in msg_content
|
||||
assert 'owner/repo' in msg_content
|
||||
assert '/workspace/repo' in msg_content
|
||||
|
||||
assert isinstance(manager.get_example_user_message(), str)
|
||||
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[TextContent(text='Hello, flarglebargle!')],
|
||||
)
|
||||
manager.enhance_message(message)
|
||||
assert len(message.content) == 2
|
||||
assert 'magic word' in message.content[0].text
|
||||
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_prompt_manager_file_not_found(prompt_dir):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
BaseMicroAgent.load(
|
||||
os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_manager_template_rendering(prompt_dir):
|
||||
"""Test PromptManager's template rendering functionality."""
|
||||
# Create temporary template files
|
||||
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
|
||||
f.write("""System prompt: bar""")
|
||||
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
|
||||
f.write('User prompt: foo')
|
||||
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
|
||||
f.write("""
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
# Test without GitHub repo
|
||||
manager = PromptManager(prompt_dir)
|
||||
manager = PromptManager(prompt_dir, microagent_dir='')
|
||||
assert manager.get_system_message() == 'System prompt: bar'
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Test with GitHub repo
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
|
||||
|
||||
# verify its parts are rendered
|
||||
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert manager.repository_info.repo_name == 'owner/repo'
|
||||
system_msg = manager.get_system_message()
|
||||
assert 'System prompt: bar' in system_msg
|
||||
|
||||
# Test building additional info
|
||||
additional_info = manager.build_workspace_context(
|
||||
repository_info=repo_info, runtime_info=None, repo_instructions=''
|
||||
# Initial user message should have repo info
|
||||
initial_msg = Message(
|
||||
role='user', content=[TextContent(text='Ask me what your task is.')]
|
||||
)
|
||||
assert '<REPOSITORY_INFO>' in additional_info
|
||||
manager.add_info_to_initial_message(initial_msg)
|
||||
msg_content: str = initial_msg.content[0].text
|
||||
assert '<REPOSITORY_INFO>' in msg_content
|
||||
assert (
|
||||
"At the user's request, repository owner/repo has been cloned to the current working directory /workspace/repo."
|
||||
in additional_info
|
||||
in msg_content
|
||||
)
|
||||
assert '</REPOSITORY_INFO>' in additional_info
|
||||
assert '</REPOSITORY_INFO>' in msg_content
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
|
||||
os.remove(os.path.join(prompt_dir, 'additional_info.j2'))
|
||||
|
||||
|
||||
def test_prompt_manager_file_not_found(prompt_dir):
|
||||
"""Test PromptManager behavior when a template file is not found."""
|
||||
# Test with a non-existent template
|
||||
with pytest.raises(FileNotFoundError):
|
||||
BaseMicroAgent.load(
|
||||
os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md')
|
||||
)
|
||||
def test_prompt_manager_repository_info(prompt_dir):
|
||||
# Test RepositoryInfo defaults
|
||||
repo_info = RepositoryInfo()
|
||||
assert repo_info.repo_name is None
|
||||
assert repo_info.repo_directory is None
|
||||
|
||||
# Test setting repository info
|
||||
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
|
||||
assert manager.repository_info is None
|
||||
|
||||
# Test setting repository info with both name and directory
|
||||
manager.set_repository_info('owner/repo2', '/workspace/repo2')
|
||||
assert manager.repository_info.repo_name == 'owner/repo2'
|
||||
assert manager.repository_info.repo_directory == '/workspace/repo2'
|
||||
|
||||
|
||||
def test_prompt_manager_disabled_microagents(prompt_dir):
|
||||
# Create test microagent files
|
||||
microagent1_name = 'test_microagent1'
|
||||
microagent2_name = 'test_microagent2'
|
||||
microagent1_content = """
|
||||
---
|
||||
name: Test Microagent 1
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- test1
|
||||
---
|
||||
|
||||
Test microagent 1 content
|
||||
"""
|
||||
microagent2_content = """
|
||||
---
|
||||
name: Test Microagent 2
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- test2
|
||||
---
|
||||
|
||||
Test microagent 2 content
|
||||
"""
|
||||
|
||||
# Create temporary micro agent files
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md'), 'w') as f:
|
||||
f.write(microagent1_content)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md'), 'w') as f:
|
||||
f.write(microagent2_content)
|
||||
|
||||
# Test that specific microagents can be disabled
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
disabled_microagents=['Test Microagent 1'],
|
||||
)
|
||||
|
||||
assert len(manager.knowledge_microagents) == 1
|
||||
assert 'Test Microagent 2' in manager.knowledge_microagents
|
||||
assert 'Test Microagent 1' not in manager.knowledge_microagents
|
||||
|
||||
# Test that all microagents are enabled by default
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
)
|
||||
|
||||
assert len(manager.knowledge_microagents) == 2
|
||||
assert 'Test Microagent 1' in manager.knowledge_microagents
|
||||
assert 'Test Microagent 2' in manager.knowledge_microagents
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md'))
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_multiple_text_contents(prompt_dir):
|
||||
# Create a test microagent that triggers on a specific keyword
|
||||
microagent_name = 'keyword_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: KeywordMicroAgent
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- triggerkeyword
|
||||
---
|
||||
|
||||
This is special information about the triggerkeyword.
|
||||
"""
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
|
||||
)
|
||||
|
||||
# Test that it matches the trigger in the last TextContent
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(text='This is some initial context.'),
|
||||
TextContent(text='This is a message without triggers.'),
|
||||
TextContent(text='This contains the triggerkeyword that should match.'),
|
||||
],
|
||||
)
|
||||
|
||||
manager.enhance_message(message)
|
||||
|
||||
# Should have added a TextContent with the microagent info at the beginning
|
||||
assert len(message.content) == 4
|
||||
assert 'special information about the triggerkeyword' in message.content[0].text
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_image_content(prompt_dir):
|
||||
# Create a test microagent that triggers on a specific keyword
|
||||
microagent_name = 'image_test_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: ImageTestMicroAgent
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- imagekeyword
|
||||
---
|
||||
|
||||
This is information related to imagekeyword.
|
||||
"""
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
|
||||
)
|
||||
|
||||
# Test with mix of ImageContent and TextContent
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(text='This is some initial text.'),
|
||||
ImageContent(image_urls=['https://example.com/image.jpg']),
|
||||
TextContent(text='This mentions imagekeyword that should match.'),
|
||||
],
|
||||
)
|
||||
|
||||
manager.enhance_message(message)
|
||||
|
||||
# Should have added a TextContent with the microagent info at the beginning
|
||||
assert len(message.content) == 4
|
||||
assert 'information related to imagekeyword' in message.content[0].text
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_only_image_content(prompt_dir):
|
||||
# Create a test microagent
|
||||
microagent_name = 'image_only_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: ImageOnlyMicroAgent
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- anytrigger
|
||||
---
|
||||
|
||||
This should not appear in the enhanced message.
|
||||
"""
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
|
||||
)
|
||||
|
||||
# Test with only ImageContent
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
ImageContent(
|
||||
image_urls=[
|
||||
'https://example.com/image1.jpg',
|
||||
'https://example.com/image2.jpg',
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Should not raise any exceptions
|
||||
manager.enhance_message(message)
|
||||
|
||||
# Should not have added any content
|
||||
assert len(message.content) == 1
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_reversed_order(prompt_dir):
|
||||
# Create a test microagent
|
||||
microagent_name = 'reversed_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: ReversedMicroAgent
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- lasttrigger
|
||||
---
|
||||
|
||||
This is specific information about the lasttrigger.
|
||||
"""
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
|
||||
)
|
||||
|
||||
# Test where the text content is not at the end of the list
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
ImageContent(image_urls=['https://example.com/image1.jpg']),
|
||||
TextContent(text='This contains the lasttrigger word.'),
|
||||
ImageContent(image_urls=['https://example.com/image2.jpg']),
|
||||
],
|
||||
)
|
||||
|
||||
manager.enhance_message(message)
|
||||
|
||||
# Should have added a TextContent with the microagent info at the beginning
|
||||
assert len(message.content) == 4
|
||||
assert isinstance(message.content[0], TextContent)
|
||||
assert 'specific information about the lasttrigger' in message.content[0].text
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_enhance_message_with_empty_content(prompt_dir):
|
||||
# Create a test microagent
|
||||
microagent_name = 'empty_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: EmptyMicroAgent
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- emptytrigger
|
||||
---
|
||||
|
||||
This should not appear in the enhanced message.
|
||||
"""
|
||||
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro')
|
||||
)
|
||||
|
||||
# Test with empty content
|
||||
message = Message(role='user', content=[])
|
||||
|
||||
# Should not raise any exceptions
|
||||
manager.enhance_message(message)
|
||||
|
||||
# Should not have added any content
|
||||
assert len(message.content) == 0
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
|
||||
def test_build_microagent_info(prompt_dir):
|
||||
@@ -85,25 +429,33 @@ def test_build_microagent_info(prompt_dir):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
{{ agent_info.content }}
|
||||
{{ agent_info.agent.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
""")
|
||||
|
||||
# Create test microagents
|
||||
class MockKnowledgeMicroAgent:
|
||||
def __init__(self, name, content):
|
||||
self.name = name
|
||||
self.content = content
|
||||
|
||||
agent1 = MockKnowledgeMicroAgent(
|
||||
name='test_agent1', content='This is information from agent 1'
|
||||
)
|
||||
|
||||
agent2 = MockKnowledgeMicroAgent(
|
||||
name='test_agent2', content='This is information from agent 2'
|
||||
)
|
||||
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Test with a single triggered agent
|
||||
triggered_agents = [
|
||||
MicroagentKnowledge(
|
||||
name='test_agent1',
|
||||
trigger='keyword1',
|
||||
content='This is information from agent 1',
|
||||
)
|
||||
]
|
||||
triggered_agents = [{'agent': agent1, 'trigger_word': 'keyword1'}]
|
||||
result = manager.build_microagent_info(triggered_agents)
|
||||
expected = """<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "keyword1".
|
||||
@@ -115,16 +467,8 @@ This is information from agent 1
|
||||
|
||||
# Test with multiple triggered agents
|
||||
triggered_agents = [
|
||||
MicroagentKnowledge(
|
||||
name='test_agent1',
|
||||
trigger='keyword1',
|
||||
content='This is information from agent 1',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='test_agent2',
|
||||
trigger='keyword2',
|
||||
content='This is information from agent 2',
|
||||
),
|
||||
{'agent': agent1, 'trigger_word': 'keyword1'},
|
||||
{'agent': agent2, 'trigger_word': 'keyword2'},
|
||||
]
|
||||
result = manager.build_microagent_info(triggered_agents)
|
||||
expected = """<EXTRA_INFO>
|
||||
@@ -147,125 +491,71 @@ This is information from agent 2
|
||||
assert result.strip() == ''
|
||||
|
||||
|
||||
def test_add_examples_to_initial_message(prompt_dir):
|
||||
"""Test adding example messages to an initial message."""
|
||||
# Create a user_prompt.j2 template file
|
||||
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
|
||||
f.write('This is an example user message')
|
||||
def test_enhance_message_with_microagent_info_template(prompt_dir):
|
||||
"""Test that enhance_message correctly uses the microagent_info template."""
|
||||
# Prepare a microagent_info.j2 template file if it doesn't exist
|
||||
template_path = os.path.join(prompt_dir, 'microagent_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
with open(template_path, 'w') as f:
|
||||
f.write("""{% for agent_info in triggered_agents %}
|
||||
<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Create a message
|
||||
message = Message(role='user', content=[TextContent(text='Original content')])
|
||||
|
||||
# Add examples to the message
|
||||
manager.add_examples_to_initial_message(message)
|
||||
|
||||
# Check that the example was added at the beginning
|
||||
assert len(message.content) == 2
|
||||
assert message.content[0].text == 'This is an example user message'
|
||||
assert message.content[1].text == 'Original content'
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))
|
||||
|
||||
|
||||
def test_add_turns_left_reminder(prompt_dir):
|
||||
"""Test adding turns left reminder to messages."""
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Create a State object with specific iteration values
|
||||
state = State()
|
||||
state.iteration = 3
|
||||
state.max_iterations = 10
|
||||
|
||||
# Create a list of messages with a user message
|
||||
user_message = Message(role='user', content=[TextContent(text='User content')])
|
||||
assistant_message = Message(
|
||||
role='assistant', content=[TextContent(text='Assistant content')]
|
||||
)
|
||||
messages = [assistant_message, user_message]
|
||||
|
||||
# Add turns left reminder
|
||||
manager.add_turns_left_reminder(messages, state)
|
||||
|
||||
# Check that the reminder was added to the latest user message
|
||||
assert len(user_message.content) == 2
|
||||
assert (
|
||||
'ENVIRONMENT REMINDER: You have 7 turns left to complete the task.'
|
||||
in user_message.content[1].text
|
||||
)
|
||||
|
||||
|
||||
def test_build_workspace_context_with_repo_and_runtime(prompt_dir):
|
||||
"""Test building additional info with repository and runtime information."""
|
||||
# Create an additional_info.j2 template file
|
||||
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
|
||||
f.write("""
|
||||
{% if repository_info %}
|
||||
<REPOSITORY_INFO>
|
||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
|
||||
</REPOSITORY_INFO>
|
||||
{% endif %}
|
||||
|
||||
{% if repository_instructions %}
|
||||
<REPOSITORY_INSTRUCTIONS>
|
||||
{{ repository_instructions }}
|
||||
</REPOSITORY_INSTRUCTIONS>
|
||||
{% endif %}
|
||||
|
||||
{% if runtime_info and (runtime_info.available_hosts or runtime_info.additional_agent_instructions) -%}
|
||||
<RUNTIME_INFORMATION>
|
||||
{% if runtime_info.available_hosts %}
|
||||
The user has access to the following hosts for accessing a web application,
|
||||
each of which has a corresponding port:
|
||||
{% for host, port in runtime_info.available_hosts.items() %}
|
||||
* {{ host }} (port {{ port }})
|
||||
{{ agent_info.agent.content }}
|
||||
</EXTRA_INFO>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if runtime_info.additional_agent_instructions %}
|
||||
{{ runtime_info.additional_agent_instructions }}
|
||||
{% endif %}
|
||||
</RUNTIME_INFORMATION>
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
# Initialize the PromptManager
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
# Create a test microagent
|
||||
microagent_name = 'test_trigger_microagent'
|
||||
microagent_content = """
|
||||
---
|
||||
name: test_trigger
|
||||
type: knowledge
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- test_trigger
|
||||
---
|
||||
|
||||
# Create repository and runtime information
|
||||
repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo')
|
||||
runtime_info = RuntimeInfo(
|
||||
available_hosts={'example.com': 8080},
|
||||
additional_agent_instructions='You know everything about this runtime.',
|
||||
)
|
||||
repo_instructions = 'This repository contains important code.'
|
||||
This is triggered content for testing the microagent_info template.
|
||||
"""
|
||||
|
||||
# Build additional info
|
||||
result = manager.build_workspace_context(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
# Create the microagent file
|
||||
os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True)
|
||||
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
|
||||
f.write(microagent_content)
|
||||
|
||||
# Initialize the PromptManager with the microagent directory
|
||||
manager = PromptManager(
|
||||
prompt_dir=prompt_dir,
|
||||
microagent_dir=os.path.join(prompt_dir, 'micro'),
|
||||
)
|
||||
|
||||
# Check that all information is included
|
||||
assert '<REPOSITORY_INFO>' in result
|
||||
assert 'owner/repo' in result
|
||||
assert '/workspace/repo' in result
|
||||
assert '<REPOSITORY_INSTRUCTIONS>' in result
|
||||
assert 'This repository contains important code.' in result
|
||||
assert '<RUNTIME_INFORMATION>' in result
|
||||
assert 'example.com (port 8080)' in result
|
||||
assert 'You know everything about this runtime.' in result
|
||||
# Create a message with a trigger keyword
|
||||
message = Message(
|
||||
role='user',
|
||||
content=[
|
||||
TextContent(text="Here's a message containing the test_trigger keyword")
|
||||
],
|
||||
)
|
||||
|
||||
# Enhance the message
|
||||
manager.enhance_message(message)
|
||||
|
||||
# The message should now have extra content at the beginning
|
||||
assert len(message.content) == 2
|
||||
assert isinstance(message.content[0], TextContent)
|
||||
|
||||
# Verify the template was correctly rendered
|
||||
expected_text = """<EXTRA_INFO>
|
||||
The following information has been included based on a keyword match for "test_trigger".
|
||||
It may or may not be relevant to the user's request.
|
||||
|
||||
This is triggered content for testing the microagent_info template.
|
||||
</EXTRA_INFO>"""
|
||||
|
||||
assert message.content[0].text.strip() == expected_text.strip()
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(prompt_dir, 'additional_info.j2'))
|
||||
|
||||
|
||||
def test_prompt_manager_initialization_error():
|
||||
"""Test that PromptManager raises an error if the prompt directory is not set."""
|
||||
with pytest.raises(ValueError, match='Prompt directory is not set'):
|
||||
PromptManager(None)
|
||||
os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'))
|
||||
|
||||
@@ -49,7 +49,6 @@ async def test_iterate_single_page():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -59,7 +58,6 @@ async def test_iterate_single_page():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
@@ -88,7 +86,6 @@ async def test_iterate_multiple_pages():
|
||||
{
|
||||
'conversation_id': f'conv{i}',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': f'Conversation {i}',
|
||||
'created_at': f'2025-01-{15+i}T19:51:04Z',
|
||||
@@ -123,7 +120,6 @@ async def test_iterate_with_invalid_conversation():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Valid conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
|
||||
@@ -61,7 +61,7 @@ async def test_init_new_local_session():
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
@@ -93,18 +93,10 @@ async def test_join_local_session():
|
||||
'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
@@ -136,7 +128,7 @@ async def test_add_to_local_event_stream():
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1, '12345'
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
)
|
||||
await conversation_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
|
||||
@@ -23,13 +23,7 @@ def mock_event_stream():
|
||||
def mock_agent():
|
||||
agent = MagicMock()
|
||||
agent.llm = MagicMock()
|
||||
|
||||
# Create a step function that returns an action without an ID
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
|
||||
agent.llm.config = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user