Compare commits

...

28 Commits

Author SHA1 Message Date
openhands
4ea4fb2476 Fix GitHub API authorization headers to use Bearer consistently (issue #6032) 2025-03-17 15:40:56 +00:00
Engel Nyst
a4b836b5f9 Don't try to send the new events in the UI (#7277) 2025-03-17 14:50:22 +01:00
Xingyao Wang
a4d632498c SWE-Gym rollout stability fix & using a validated SWE-Gym set (#7182)
Co-authored-by: Robert Brennan <accounts@rbren.io>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-03-17 21:15:01 +08:00
Engel Nyst
4f017081fc Quick fix docs (#7299)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-17 05:50:05 +00:00
Engel Nyst
51fb1fae88 RecallObservations (#7292) 2025-03-17 03:18:22 +01:00
Graham Neubig
106b230fea Update Slack invitation links (#7296)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-17 02:06:48 +00:00
Xingyao Wang
9b262dd057 fix retry on ConnectionError & retry on remote runtime by default (#7294) 2025-03-17 01:18:54 +00:00
chuckbutkus
8074b261d3 Move current user_id to github_user_id and create a new user_id field (#7231)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2025-03-16 16:32:27 -04:00
dependabot[bot]
999a59f938 chore(deps): bump the version-all group with 5 updates (#7253)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 19:51:08 +00:00
chuckbutkus
fbba57d3b5 Fix saving of settings (#7282) 2025-03-16 19:06:46 +00:00
Engel Nyst
3f6c8a2338 Fix visual browsing (#7278)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 16:50:25 +01:00
Engel Nyst
dd09d46ccb Remove DelegatorAgent (fix #7280)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 16:49:28 +01:00
tofarr
8897b45eeb Fix for too much reaction in logs (#7276) 2025-03-16 08:21:30 -06:00
Ryan H. Tran
30109e8f20 Separate tool descriptions to support models with limited description length (#7258) 2025-03-16 09:48:13 +01:00
Engel Nyst
cc45f5d9c3 Add RecallActions and observations for retrieval of prompt extensions (#6909)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
2025-03-15 21:48:37 +01:00
tofarr
e34a771e66 Fix for issue where initial command fails (#7254) 2025-03-14 14:49:57 -06:00
tofarr
ec763f8105 Fix for error where credits is accessed even when billing is disabled (#7250) 2025-03-14 15:10:54 +00:00
Ryan H. Tran
165c0cc42e Add doc for local runtime (#7234)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
2025-03-14 22:09:33 +08:00
dependabot[bot]
1b4f15235e chore(deps): bump the version-all group with 4 updates (#7241)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-13 22:15:47 +01:00
Calvin Smith
303b7ab180 (fix): Conditional imports resolved in SWE-bench eval script while multiprocessing enabled (#7244)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
2025-03-13 13:29:11 -06:00
Rohit Malhotra
78d185b102 [Feat]: Support Gitlab PAT (#7064)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-13 14:44:49 -04:00
Ryan H. Tran
300bfbdf2d Upgrade openhands-aci to 0.2.6 (#7233) 2025-03-14 02:10:59 +08:00
Xingyao Wang
e2f414bf26 chore: update doc for allhands doc (#7242)
Co-authored-by: mamoodi <mamoodiha@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2025-03-14 00:32:17 +08:00
Calvin Smith
3b955dd9d5 (fix): Improve formatting of summarizing LLM inputs (#7239)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
2025-03-13 10:21:29 -06:00
sp.wack
f1eb1f59c3 hotfix(frontend): Fix overriding tailwind animation classes (#7243) 2025-03-13 16:03:54 +00:00
sp.wack
e1f6929d98 feat: saas new user modal (#7098)
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-13 19:15:57 +04:00
Engel Nyst
2a7f926591 Detect condensation loops at 10 repetitions, not 3 (#7237) 2025-03-13 14:32:01 +00:00
Xingyao Wang
b8daab721d Update agent message to use first-person perspective (#7197)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-13 04:39:24 +08:00
125 changed files with 8596 additions and 1653 deletions

View File

@@ -12,7 +12,7 @@
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
<br/>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
<br/>
@@ -96,7 +96,7 @@ troubleshooting resources, and advanced configuration options.
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg) - Here we talk about research, architecture, and future development.
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw) - Here we talk about research, architecture, and future development.
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.

View File

@@ -42,7 +42,7 @@ Explorez le code source d'OpenHands sur [GitHub](https://github.com/All-Hands-AI
/>
</a>
<br></br>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
<img
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
alt="Join our Slack community"

View File

@@ -42,7 +42,7 @@ OpenHands 是一个**自主 AI 软件工程师**,能够执行复杂的工程
/>
</a>
<br></br>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
<img
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
alt="Join our Slack community"

View File

@@ -10,6 +10,9 @@ We also support "remote" runtimes, which are typically managed by third-parties.
They can make setup a bit simpler and more scalable, especially
if you're running many OpenHands conversations in parallel (e.g. to do evaluation).
Additionally, we provide a "local" runtime that runs directly on your machine without Docker,
which can be useful in controlled environments like CI pipelines.
## Docker Runtime
This is the default Runtime that's used when you start OpenHands. You might notice
some flags being passed to `docker run` that make this possible:
@@ -56,11 +59,12 @@ any files that are mounted into its workspace.
This setup can cause some issues with file permissions (hence the `SANDBOX_USER_ID` variable)
but seems to work well on most systems.
## All Hands Runtime
The All Hands Runtime is currently in beta. You can request access by joining
the #remote-runtime-limited-beta channel on Slack ([see the README](https://github.com/All-Hands-AI/OpenHands?tab=readme-ov-file#-how-to-join-the-community) for an invite).
## OpenHands Remote Runtime
To use the All Hands Runtime, set the following environment variables when
OpenHands Remote Runtime is currently in beta (read [here](https://runtime.all-hands.dev/) for more details), it allows you to launch runtimes in parallel in the cloud.
Fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSckVz_JFwg2_mOxNZjCtr7aoBFI2Mwdan3f75J_TrdMS1JV2g/viewform) to apply if you want to try this out!
To use the OpenHands Remote Runtime, set the following environment variables when
starting OpenHands:
```bash
@@ -117,3 +121,66 @@ bash -i <(curl -sL https://get.daytona.io/openhands)
Once executed, OpenHands should be running locally and ready for use.
For more details and manual initialization, view the entire [README.md](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/impl/daytona/README.md)
## Local Runtime
The Local Runtime allows the OpenHands agent to execute actions directly on your local machine without using Docker. This runtime is primarily intended for controlled environments like CI pipelines or testing scenarios where Docker is not available.
:::caution
**Security Warning**: The Local Runtime runs without any sandbox isolation. The agent can directly access and modify files on your machine. Only use this runtime in controlled environments or when you fully understand the security implications.
:::
### Prerequisites
Before using the Local Runtime, ensure you have the following dependencies installed:
1. You have followed the [Development setup instructions](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md).
2. tmux is available on your system.
### Configuration
To use the Local Runtime, besides required configurations like the model, API key, you'll need to set the following options via environment variables or the [config.toml file](https://github.com/All-Hands-AI/OpenHands/blob/main/config.template.toml) when starting OpenHands:
- Via environment variables:
```bash
# Required
export RUNTIME=local
# Optional but recommended
export WORKSPACE_BASE=/path/to/your/workspace
```
- Via `config.toml`:
```toml
[core]
runtime = "local"
workspace_base = "/path/to/your/workspace"
```
If `WORKSPACE_BASE` is not set, the runtime will create a temporary directory for the agent to work in.
### Example Usage
Here's an example of how to start OpenHands with the Local Runtime in Headless Mode:
```bash
# Set the runtime type to local
export RUNTIME=local
# Optionally set a workspace directory
export WORKSPACE_BASE=/path/to/your/project
# Start OpenHands
poetry run python -m openhands.core.main -t "write a bash script that prints hi"
```
### Use Cases
The Local Runtime is particularly useful for:
- CI/CD pipelines where Docker is not available.
- Testing and development of OpenHands itself.
- Environments where container usage is restricted.
- Scenarios where direct file system access is required.

View File

@@ -8,7 +8,7 @@ function CustomFooter() {
<footer className="custom-footer">
<div className="footer-content">
<div className="footer-icons">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg" target="_blank" rel="noopener noreferrer">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw" target="_blank" rel="noopener noreferrer">
<FaSlack />
</a>
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">

View File

@@ -46,7 +46,7 @@ export function HomepageHeader() {
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
<br/>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
<br/>

View File

@@ -1,9 +1,12 @@
import copy
import json
import os
import subprocess
import tempfile
import time
from dataclasses import dataclass
from functools import partial
from typing import Callable
import pandas as pd
from tqdm import tqdm
@@ -91,12 +94,22 @@ def get_config(metadata: EvalMetadata, instance: pd.Series) -> AppConfig:
return config
@dataclass
class ConditionalImports:
"""We instantiate the values in this dataclass differently if we're evaluating SWE-bench or SWE-Gym."""
get_eval_report: Callable
APPLY_PATCH_FAIL: str
APPLY_PATCH_PASS: str
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
log_dir: str | None = None,
runtime_failure_count: int = 0,
conditional_imports: ConditionalImports | None = None,
) -> EvalOutput:
"""
Evaluate agent performance on a SWE-bench problem instance.
@@ -108,9 +121,18 @@ def process_instance(
log_dir (str | None, default=None): Path to directory where log files will be written. Must
be provided if `reset_logger` is set.
conditional_imports: A dataclass containing values that are imported differently based on
whether we're evaluating SWE-bench or SWE-Gym.
Raises:
AssertionError: if the `reset_logger` flag is set without a provided log directory.
AssertionError: if `conditional_imports` is not provided.
"""
assert (
conditional_imports is not None
), 'conditional_imports must be provided to run process_instance using multiprocessing'
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:
assert (
@@ -124,7 +146,7 @@ def process_instance(
config = get_config(metadata, instance)
instance_id = instance.instance_id
model_patch = instance['model_patch']
test_spec: TestSpec = instance['test_spec']
test_spec = instance['test_spec']
logger.info(f'Starting evaluation for instance {instance_id}.')
if 'test_result' not in instance.keys():
@@ -154,6 +176,11 @@ def process_instance(
logger.warning(
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
)
metadata = copy.deepcopy(metadata)
metadata.details['runtime_failure_count'] = runtime_failure_count
metadata.details['remote_runtime_resource_factor'] = (
config.sandbox.remote_runtime_resource_factor
)
try:
runtime = create_runtime(config)
@@ -196,7 +223,9 @@ def process_instance(
instance['test_result']['apply_patch_output'] = apply_patch_output
if 'APPLY_PATCH_FAIL' in apply_patch_output:
logger.info(f'[{instance_id}] {APPLY_PATCH_FAIL}:\n{apply_patch_output}')
logger.info(
f'[{instance_id}] {conditional_imports.APPLY_PATCH_FAIL}:\n{apply_patch_output}'
)
instance['test_result']['report']['failed_apply_patch'] = True
return EvalOutput(
@@ -205,7 +234,9 @@ def process_instance(
metadata=metadata,
)
elif 'APPLY_PATCH_PASS' in apply_patch_output:
logger.info(f'[{instance_id}] {APPLY_PATCH_PASS}:\n{apply_patch_output}')
logger.info(
f'[{instance_id}] {conditional_imports.APPLY_PATCH_PASS}:\n{apply_patch_output}'
)
# Run eval script in background and save output to log file
log_file = '/tmp/eval_output.log'
@@ -271,14 +302,20 @@ def process_instance(
with open(test_output_path, 'w') as f:
f.write(test_output)
try:
_report = get_eval_report(
extra_kwargs = {}
if 'SWE-Gym' in metadata.dataset:
# SWE-Gym uses a different version of the package, hence a different eval report argument
extra_kwargs['log_path'] = test_output_path
else:
extra_kwargs['test_log_path'] = test_output_path
_report = conditional_imports.get_eval_report(
test_spec=test_spec,
prediction={
'model_patch': model_patch,
'instance_id': instance_id,
},
test_log_path=test_output_path,
include_tests_status=True,
**extra_kwargs,
)
report = _report[instance_id]
logger.info(
@@ -345,7 +382,6 @@ if __name__ == '__main__':
)
from swegym.harness.test_spec import (
SWEbenchInstance,
TestSpec,
make_test_spec,
)
from swegym.harness.utils import load_swebench_dataset
@@ -357,7 +393,6 @@ if __name__ == '__main__':
)
from swebench.harness.test_spec.test_spec import (
SWEbenchInstance,
TestSpec,
make_test_spec,
)
from swebench.harness.utils import load_swebench_dataset
@@ -440,12 +475,21 @@ if __name__ == '__main__':
.decode('utf-8')
.strip(), # Current commit
dataset=args.dataset, # Dataset name from args
details={},
)
# The evaluation harness constrains the signature of `process_instance_func` but we need to
# pass extra information. Build a new function object to avoid issues with multiprocessing.
process_instance_func = partial(
process_instance, log_dir=output_file.replace('.jsonl', '.logs')
process_instance,
log_dir=output_file.replace('.jsonl', '.logs'),
# We have to explicitly pass these imports to the process_instance function, otherwise
# they won't be available in the multiprocessing context.
conditional_imports=ConditionalImports(
get_eval_report=get_eval_report,
APPLY_PATCH_FAIL=APPLY_PATCH_FAIL,
APPLY_PATCH_PASS=APPLY_PATCH_PASS,
),
)
run_evaluation(

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
if dataset_name not in _global_resource_mapping:
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
if not os.path.exists(file_path):
logger.warning(f'Resource mapping for {dataset_name} not found.')
logger.info(f'Resource mapping for {dataset_name} not found.')
return None
with open(file_path, 'r') as f:

View File

@@ -1,4 +1,5 @@
import asyncio
import copy
import json
import os
import tempfile
@@ -149,7 +150,8 @@ def get_config(
) -> AppConfig:
# We use a different instance image for the each instance of swe-bench eval
use_official_image = bool(
'verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower()
('verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower())
and 'swe-gym' not in metadata.dataset.lower()
)
base_container_image = get_instance_docker_image(
instance['instance_id'], use_official_image
@@ -475,6 +477,13 @@ def process_instance(
logger.warning(
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
)
metadata = copy.deepcopy(metadata)
metadata.details['runtime_failure_count'] = runtime_failure_count
metadata.details['remote_runtime_resource_factor'] = (
config.sandbox.remote_runtime_resource_factor
)
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
@@ -560,20 +569,6 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
return dataset
# A list of instances that are known to be tricky to infer
# (will cause runtime failure even with resource factor = 8)
SWEGYM_EXCLUDE_IDS = [
'dask__dask-10422',
'pandas-dev__pandas-50548',
'pandas-dev__pandas-53672',
'pandas-dev__pandas-54174',
'pandas-dev__pandas-55518',
'pandas-dev__pandas-58383',
'pydata__xarray-6721',
'pytest-dev__pytest-10081',
'pytest-dev__pytest-7236',
]
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
@@ -598,11 +593,20 @@ if __name__ == '__main__':
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
)
if 'SWE-Gym' in args.dataset:
swe_bench_tests = swe_bench_tests[
~swe_bench_tests['instance_id'].isin(SWEGYM_EXCLUDE_IDS)
]
with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'split',
'swegym_verified_instances.json',
),
'r',
) as f:
swegym_verified_instances = json.load(f)
swe_bench_tests = swe_bench_tests[
swe_bench_tests['instance_id'].isin(swegym_verified_instances)
]
logger.info(
f'{len(swe_bench_tests)} tasks left after excluding SWE-Gym excluded tasks'
f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances'
)
llm_config = None

View File

@@ -9,7 +9,7 @@ parser.add_argument(
'--dataset_name',
type=str,
help='Name of the dataset to download',
default='princeton-nlp/SWE-bench_Lite',
default='princeton-nlp/SWE-bench_Verified',
)
parser.add_argument('--split', type=str, help='Split to download', default='test')
args = parser.parse_args()
@@ -20,7 +20,12 @@ print(
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
)
patches = [
{'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
{
'instance_id': row['instance_id'],
'model_patch': row['patch'],
'model_name_or_path': 'gold',
}
for row in dataset
]
print(f'{len(patches)} gold patches loaded')
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,6 @@ from openhands.utils.async_utils import call_async_from_sync
FAKE_RESPONSES = {
'CodeActAgent': fake_user_response,
'DelegatorAgent': fake_user_response,
'VisualBrowsingAgent': fake_user_response,
}

View File

@@ -4,8 +4,10 @@ import userEvent from "@testing-library/user-event";
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
import OpenHands from "#/api/open-hands";
import { PaymentForm } from "#/components/features/payment/payment-form";
import * as featureFlags from "#/utils/feature-flags";
describe("PaymentForm", () => {
const billingSettingsSpy = vi.spyOn(featureFlags, "BILLING_SETTINGS");
const getBalanceSpy = vi.spyOn(OpenHands, "getBalance");
const createCheckoutSessionSpy = vi.spyOn(OpenHands, "createCheckoutSession");
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
@@ -26,6 +28,7 @@ describe("PaymentForm", () => {
GITHUB_CLIENT_ID: "123",
POSTHOG_CLIENT_KEY: "456",
});
billingSettingsSpy.mockReturnValue(true);
});
afterEach(() => {

View File

@@ -26,33 +26,32 @@ const createAxiosNotFoundErrorObject = () =>
},
);
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
const RouterStub = createRoutesStub([
{
// layout route
Component: MainApp,
path: "/",
children: [
{
// home route
Component: Home,
path: "/",
},
{
Component: SettingsScreen,
path: "/settings",
},
],
},
]);
afterEach(() => {
vi.clearAllMocks();
});
describe("Home Screen", () => {
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
const RouterStub = createRoutesStub([
{
// layout route
Component: MainApp,
path: "/",
children: [
{
// home route
Component: Home,
path: "/",
},
{
Component: SettingsScreen,
path: "/settings",
},
],
},
]);
afterEach(() => {
vi.clearAllMocks();
});
it("should render the home screen", () => {
renderWithProviders(<RouterStub initialEntries={["/"]} />);
});
@@ -79,57 +78,82 @@ describe("Home Screen", () => {
const settingsScreen = await screen.findByTestId("settings-screen");
expect(settingsScreen).toBeInTheDocument();
});
});
describe("Settings 404", () => {
it("should open the settings modal if GET /settings fails with a 404", async () => {
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
describe("Settings 404", () => {
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
renderWithProviders(<RouterStub initialEntries={["/"]} />);
it("should open the settings modal if GET /settings fails with a 404", async () => {
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
const settingsModal = await screen.findByTestId("ai-config-modal");
expect(settingsModal).toBeInTheDocument();
});
renderWithProviders(<RouterStub initialEntries={["/"]} />);
it("should navigate to the settings screen when clicking the advanced settings button", async () => {
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
const settingsModal = await screen.findByTestId("ai-config-modal");
expect(settingsModal).toBeInTheDocument();
});
const user = userEvent.setup();
renderWithProviders(<RouterStub initialEntries={["/"]} />);
it("should navigate to the settings screen when clicking the advanced settings button", async () => {
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
const settingsScreen = screen.queryByTestId("settings-screen");
expect(settingsScreen).not.toBeInTheDocument();
const user = userEvent.setup();
renderWithProviders(<RouterStub initialEntries={["/"]} />);
const settingsModal = await screen.findByTestId("ai-config-modal");
expect(settingsModal).toBeInTheDocument();
const settingsScreen = screen.queryByTestId("settings-screen");
expect(settingsScreen).not.toBeInTheDocument();
const advancedSettingsButton = await screen.findByTestId(
"advanced-settings-link",
);
await user.click(advancedSettingsButton);
const settingsModal = await screen.findByTestId("ai-config-modal");
expect(settingsModal).toBeInTheDocument();
const settingsScreenAfter = await screen.findByTestId("settings-screen");
expect(settingsScreenAfter).toBeInTheDocument();
const advancedSettingsButton = await screen.findByTestId(
"advanced-settings-link",
);
await user.click(advancedSettingsButton);
const settingsModalAfter = screen.queryByTestId("ai-config-modal");
expect(settingsModalAfter).not.toBeInTheDocument();
});
const settingsScreenAfter = await screen.findByTestId("settings-screen");
expect(settingsScreenAfter).toBeInTheDocument();
it("should not open the settings modal if GET /settings fails but is SaaS mode", async () => {
// TODO: Remove HIDE_LLM_SETTINGS check once released
vi.spyOn(FeatureFlags, "HIDE_LLM_SETTINGS").mockReturnValue(true);
// @ts-expect-error - we only need APP_MODE for this test
getConfigSpy.mockResolvedValue({ APP_MODE: "saas" });
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
const settingsModalAfter = screen.queryByTestId("ai-config-modal");
expect(settingsModalAfter).not.toBeInTheDocument();
});
renderWithProviders(<RouterStub initialEntries={["/"]} />);
it("should not open the settings modal if GET /settings fails but is SaaS mode", async () => {
// TODO: Remove HIDE_LLM_SETTINGS check once released
vi.spyOn(FeatureFlags, "HIDE_LLM_SETTINGS").mockReturnValue(true);
// @ts-expect-error - we only need APP_MODE for this test
getConfigSpy.mockResolvedValue({ APP_MODE: "saas" });
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
// small hack to wait for the modal to not appear
await expect(
screen.findByTestId("ai-config-modal", {}, { timeout: 1000 }),
).rejects.toThrow();
});
renderWithProviders(<RouterStub initialEntries={["/"]} />);
// small hack to wait for the modal to not appear
await expect(
screen.findByTestId("ai-config-modal", {}, { timeout: 1000 }),
).rejects.toThrow();
});
});
describe("Setup Payment modal", () => {
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
afterEach(() => {
vi.resetAllMocks();
});
it("should only render if SaaS mode and is new user", async () => {
// @ts-expect-error - we only need the APP_MODE for this test
getConfigSpy.mockResolvedValue({
APP_MODE: "saas",
});
vi.spyOn(FeatureFlags, "BILLING_SETTINGS").mockReturnValue(true);
const error = createAxiosNotFoundErrorObject();
getSettingsSpy.mockRejectedValue(error);
renderWithProviders(<RouterStub initialEntries={["/"]} />);
const setupPaymentModal = await screen.findByTestId("proceed-to-stripe-button");
expect(setupPaymentModal).toBeInTheDocument();
});
});

View File

@@ -721,7 +721,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith(
expect.objectContaining({
llm_api_key: "", // empty because it's not set previously
github_token: undefined,
provider_tokens: undefined,
language: "no",
}),
);
@@ -758,7 +758,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith(
expect.objectContaining({
github_token: undefined,
provider_tokens: undefined,
llm_api_key: "", // empty because it's not set previously
llm_model: "openai/gpt-4o",
}),
@@ -801,7 +801,7 @@ describe("Settings Screen", () => {
expect(saveSettingsSpy).toHaveBeenCalledWith({
...mockCopy,
github_token: undefined, // not set
provider_tokens: undefined, // not set
llm_api_key: "", // reset as well
});
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();

View File

@@ -1,7 +1,9 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { handleStatusMessage } from "#/services/actions";
import { handleStatusMessage, handleActionMessage } from "#/services/actions";
import store from "#/store";
import { trackError } from "#/utils/error-handler";
import ActionType from "#/types/action-type";
import { ActionMessage } from "#/types/message";
// Mock dependencies
vi.mock("#/utils/error-handler", () => ({
@@ -56,4 +58,89 @@ describe("Actions Service", () => {
}));
});
});
describe("handleActionMessage", () => {
it("should use first-person perspective for task completion messages", () => {
// Test partial completion
const messagePartial: ActionMessage = {
id: 1,
action: ActionType.FINISH,
source: "agent",
message: "",
timestamp: new Date().toISOString(),
args: {
final_thought: "",
task_completed: "partial",
outputs: "",
thought: ""
}
};
// Mock implementation to capture the message
let capturedPartialMessage = "";
(store.dispatch as any).mockImplementation((action: any) => {
if (action.type === "chat/addAssistantMessage" &&
action.payload.includes("believe that the task was **completed partially**")) {
capturedPartialMessage = action.payload;
}
});
handleActionMessage(messagePartial);
expect(capturedPartialMessage).toContain("I believe that the task was **completed partially**");
// Test not completed
const messageNotCompleted: ActionMessage = {
id: 2,
action: ActionType.FINISH,
source: "agent",
message: "",
timestamp: new Date().toISOString(),
args: {
final_thought: "",
task_completed: "false",
outputs: "",
thought: ""
}
};
// Mock implementation to capture the message
let capturedNotCompletedMessage = "";
(store.dispatch as any).mockImplementation((action: any) => {
if (action.type === "chat/addAssistantMessage" &&
action.payload.includes("believe that the task was **not completed**")) {
capturedNotCompletedMessage = action.payload;
}
});
handleActionMessage(messageNotCompleted);
expect(capturedNotCompletedMessage).toContain("I believe that the task was **not completed**");
// Test completed successfully
const messageCompleted: ActionMessage = {
id: 3,
action: ActionType.FINISH,
source: "agent",
message: "",
timestamp: new Date().toISOString(),
args: {
final_thought: "",
task_completed: "true",
outputs: "",
thought: ""
}
};
// Mock implementation to capture the message
let capturedCompletedMessage = "";
(store.dispatch as any).mockImplementation((action: any) => {
if (action.type === "chat/addAssistantMessage" &&
action.payload.includes("believe that the task was **completed successfully**")) {
capturedCompletedMessage = action.payload;
}
});
handleActionMessage(messageCompleted);
expect(capturedCompletedMessage).toContain("I believe that the task was **completed successfully**");
});
});
});

View File

@@ -8,7 +8,7 @@
* - Please do NOT serve this file on production.
*/
const PACKAGE_VERSION = '2.7.0'
const PACKAGE_VERSION = '2.7.3'
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
const activeClientIds = new Set()

View File

@@ -281,6 +281,13 @@ class OpenHands {
return data.redirect_url;
}
static async createBillingSessionResponse(): Promise<string> {
const { data } = await openHands.post(
"/api/billing/create-customer-setup-session",
);
return data.redirect_url;
}
static async getBalance(): Promise<string> {
const { data } = await openHands.get<{ credits: string }>(
"/api/billing/credits",

View File

@@ -48,6 +48,7 @@ export interface GetConfigResponse {
APP_SLUG?: string;
GITHUB_CLIENT_ID: string;
POSTHOG_CLIENT_KEY: string;
STRIPE_PUBLISHABLE_KEY?: string;
}
export interface GetVSCodeUrlResponse {

View File

@@ -0,0 +1,48 @@
import { useMutation } from "@tanstack/react-query";
import { Trans, useTranslation } from "react-i18next";
import AllHandsLogo from "#/assets/branding/all-hands-logo.svg?react";
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
import { ModalBody } from "#/components/shared/modals/modal-body";
import OpenHands from "#/api/open-hands";
import { BrandButton } from "../settings/brand-button";
import { displayErrorToast } from "#/utils/custom-toast-handlers";
export function SetupPaymentModal() {
const { t } = useTranslation();
const { mutate, isPending } = useMutation({
mutationFn: OpenHands.createBillingSessionResponse,
onSuccess: (data) => {
window.location.href = data;
},
onError: () => {
displayErrorToast(t("BILLING$ERROR_WHILE_CREATING_SESSION"));
},
});
return (
<ModalBackdrop>
<ModalBody className="border border-tertiary">
<AllHandsLogo width={68} height={46} />
<div className="flex flex-col gap-2 w-full items-center text-center">
<h1 className="text-2xl font-bold">{t("BILLING$YOUVE_GOT_50")}</h1>
<p>
<Trans
i18nKey="BILLING$CLAIM_YOUR_50"
components={{ b: <strong /> }}
/>
</p>
</div>
<BrandButton
testId="proceed-to-stripe-button"
type="submit"
variant="primary"
className="w-full"
isDisabled={isPending}
onClick={mutate}
>
{t("BILLING$PROCEED_TO_STRIPE")}
</BrandButton>
</ModalBody>
</ModalBackdrop>
);
}

View File

@@ -61,7 +61,7 @@ export function Sidebar() {
displayErrorToast(
"Something went wrong while fetching settings. Please reload the page.",
);
} else if (settingsError?.status === 404) {
} else if (config?.APP_MODE === "oss" && settingsError?.status === 404) {
setSettingsModalIsOpen(true);
}
}, [

View File

@@ -17,7 +17,7 @@ const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
? ""
: settings.LLM_API_KEY?.trim() || undefined,
remote_runtime_resource_factor: settings.REMOTE_RUNTIME_RESOURCE_FACTOR,
github_token: settings.github_token,
provider_tokens: settings.provider_tokens,
unset_github_token: settings.unset_github_token,
enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER,
enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS,

View File

@@ -1,6 +1,7 @@
import { useQuery } from "@tanstack/react-query";
import { useConfig } from "./use-config";
import OpenHands from "#/api/open-hands";
import { BILLING_SETTINGS } from "#/utils/feature-flags";
export const useBalance = () => {
const { data: config } = useConfig();
@@ -8,6 +9,6 @@ export const useBalance = () => {
return useQuery({
queryKey: ["user", "balance"],
queryFn: OpenHands.getBalance,
enabled: config?.APP_MODE === "saas",
enabled: config?.APP_MODE === "saas" && BILLING_SETTINGS(),
});
};

View File

@@ -21,6 +21,8 @@ const getSettingsQueryFn = async () => {
ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser,
ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications,
USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics,
PROVIDER_TOKENS: apiSettings.provider_tokens,
IS_NEW_USER: false,
};
};

View File

@@ -312,4 +312,9 @@ export enum I18nKey {
BUTTON$MARK_NOT_HELPFUL = "BUTTON$MARK_NOT_HELPFUL",
BUTTON$EXPORT_CONVERSATION = "BUTTON$EXPORT_CONVERSATION",
BILLING$CLICK_TO_TOP_UP = "BILLING$CLICK_TO_TOP_UP",
BILLING$YOUVE_GOT_50 = "BILLING$YOUVE_GOT_50",
BILLING$ERROR_WHILE_CREATING_SESSION = "BILLING$ERROR_WHILE_CREATING_SESSION",
BILLING$CLAIM_YOUR_50 = "BILLING$CLAIM_YOUR_50",
BILLING$PROCEED_TO_STRIPE = "BILLING$PROCEED_TO_STRIPE",
BILLING$YOURE_IN = "BILLING$YOURE_IN",
}

View File

@@ -4647,5 +4647,80 @@
"fr": "Ajouter des fonds à votre compte",
"tr": "Hesabınıza bakiye ekleyin",
"de": "Guthaben zu Ihrem Konto hinzufügen"
},
"BILLING$YOUVE_GOT_50": {
"en": "You've got $50 in free OpenHands credits",
"ja": "OpenHandsの無料クレジット$50を獲得しました",
"zh-CN": "您获得了 $50 的 OpenHands 免费额度",
"zh-TW": "您獲得了 $50 的 OpenHands 免費額度",
"ko-KR": "OpenHands 무료 크레딧 $50를 받았습니다",
"no": "Du har fått $50 i gratis OpenHands-kreditter",
"it": "Hai ottenuto $50 in crediti gratuiti OpenHands",
"pt": "Você ganhou $50 em créditos gratuitos OpenHands",
"es": "Has recibido $50 en créditos gratuitos de OpenHands",
"ar": "لديك 50$ من رصيد OpenHands المجاني",
"fr": "Vous avez reçu $50 de crédits OpenHands gratuits",
"tr": "OpenHands'de $50 ücretsiz kredi kazandınız",
"de": "Sie haben $50 in kostenlosen OpenHands-Guthaben erhalten"
},
"BILLING$ERROR_WHILE_CREATING_SESSION": {
"en": "Error occurred while setting up your payment session. Please try again later.",
"ja": "お支払いセッションの設定中にエラーが発生しました。後ほど再度お試しください。",
"zh-CN": "设置支付会话时发生错误。请稍后再试。",
"zh-TW": "設置支付會話時發生錯誤。請稍後再試。",
"ko-KR": "결제 세션 설정 중 오류가 발생했습니다. 나중에 다시 시도해 주세요.",
"no": "Det oppstod en feil under oppsett av betalingsøkten. Vennligst prøv igjen senere.",
"it": "Si è verificato un errore durante la configurazione della sessione di pagamento. Si prega di riprovare più tardi.",
"pt": "Ocorreu um erro ao configurar sua sessão de pagamento. Por favor, tente novamente mais tarde.",
"es": "Se produjo un error al configurar tu sesión de pago. Por favor, inténtalo de nuevo más tarde.",
"ar": "حدث خطأ أثناء إعداد جلسة الدفع الخاصة بك. يرجى المحاولة مرة أخرى لاحقًا.",
"fr": "Une erreur s'est produite lors de la configuration de votre session de paiement. Veuillez réessayer plus tard.",
"tr": "Ödeme oturumunuz kurulurken bir hata oluştu. Lütfen daha sonra tekrar deneyin.",
"de": "Beim Einrichten Ihrer Zahlungssitzung ist ein Fehler aufgetreten. Bitte versuchen Sie es später erneut."
},
"BILLING$CLAIM_YOUR_50": {
"en": "Add a credit card with Stripe to claim your $50. <b>We won't charge you without asking first!</b>",
"ja": "Stripeでクレジットカードを追加して$50を獲得。<b>事前の確認なしで請求することはありません!</b>",
"zh-CN": "添加 Stripe 信用卡以领取 $50。<b>我们不会在未经您同意的情况下收费!</b>",
"zh-TW": "添加 Stripe 信用卡以領取 $50。<b>我們不會在未經您同意的情況下收費!</b>",
"ko-KR": "Stripe에 신용카드를 추가하여 $50를 받으세요. <b>사전 동의 없이 요금이 청구되지 않습니다!</b>",
"no": "Legg til et kredittkort med Stripe for å få $50. <b>Vi belaster deg ikke uten å spørre først!</b>",
"it": "Aggiungi una carta di credito con Stripe per ottenere $50. <b>Non ti addebiteremo nulla senza chiedere prima!</b>",
"pt": "Adicione um cartão de crédito com Stripe para receber $50. <b>Não cobraremos sem perguntar primeiro!</b>",
"es": "Añade una tarjeta de crédito con Stripe para reclamar tus $50. <b>¡No te cobraremos sin preguntarte primero!</b>",
"ar": "أضف بطاقة ائتمان مع Stripe للحصول على 50$. <b>لن نقوم بالخصم دون إذن مسبق!</b>",
"fr": "Ajoutez une carte de crédit avec Stripe pour obtenir 50$. <b>Nous ne vous facturerons pas sans vous demander d'abord !</b>",
"tr": "50$ almak için Stripe ile kredi kartı ekleyin. <b>Önce sormadan ücret almayacağız!</b>",
"de": "Fügen Sie eine Kreditkarte mit Stripe hinzu, um $50 zu erhalten. <b>Wir belasten Sie nicht ohne vorherige Zustimmung!</b>"
},
"BILLING$PROCEED_TO_STRIPE": {
"en": "Add Billing Info",
"ja": "請求情報を追加",
"zh-CN": "添加账单信息",
"zh-TW": "添加帳單資訊",
"ko-KR": "결제 정보 추가",
"no": "Legg til betalingsinformasjon",
"it": "Aggiungi informazioni di fatturazione",
"pt": "Adicionar informações de pagamento",
"es": "Añadir información de facturación",
"ar": "إضافة معلومات الفواتير",
"fr": "Ajouter les informations de facturation",
"tr": "Fatura Bilgisi Ekle",
"de": "Zahlungsinformationen hinzufügen"
},
"BILLING$YOURE_IN": {
"en": "You're in! You can start using your $50 in free credits now.",
"ja": "登録完了!$50分の無料クレジットを今すぐご利用いただけます。",
"zh-CN": "您已加入!现在可以开始使用$50的免费额度了。",
"zh-TW": "您已加入!現在可以開始使用$50的免費額度了。",
"ko-KR": "가입 완료! 지금 바로 $50 상당의 무료 크레딧을 사용하실 수 있습니다.",
"no": "Du er med! Du kan begynne å bruke dine $50 i gratis kreditter nå.",
"it": "Ci sei! Puoi iniziare a utilizzare i tuoi $50 in crediti gratuiti ora.",
"pt": "Você está dentro! Você pode começar a usar seus $50 em créditos gratuitos agora.",
"es": "¡Ya estás dentro! Puedes empezar a usar tus $50 en créditos gratuitos ahora.",
"ar": "أنت معنا! يمكنك البدء في استخدام رصيدك المجاني البالغ 50 دولارًا الآن.",
"fr": "C'est fait ! Vous pouvez commencer à utiliser vos 50 $ de crédits gratuits maintenant.",
"tr": "Başardın! Şimdi $50 değerindeki ücretsiz kredilerini kullanmaya başlayabilirsin.",
"de": "Du bist dabei! Du kannst jetzt deine $50 an kostenlosen Guthaben nutzen."
}
}

View File

@@ -1,8 +1,4 @@
import { delay, http, HttpResponse } from "msw";
import Stripe from "stripe";
const TEST_STRIPE_SECRET_KEY = "";
const PRICE_ID = "";
export const STRIPE_BILLING_HANDLERS = [
http.get("/api/billing/credits", async () => {
@@ -10,27 +6,17 @@ export const STRIPE_BILLING_HANDLERS = [
return HttpResponse.json({ credits: "100" });
}),
http.post("/api/billing/create-checkout-session", async ({ request }) => {
http.post("/api/billing/create-checkout-session", async () => {
await delay();
const body = await request.json();
return HttpResponse.json({
redirect_url: "https://stripe.com/some-checkout",
});
}),
if (body && typeof body === "object" && body.amount) {
const stripe = new Stripe(TEST_STRIPE_SECRET_KEY);
const session = await stripe.checkout.sessions.create({
line_items: [
{
price: PRICE_ID,
quantity: body.amount,
},
],
mode: "payment",
success_url: "http://localhost:3001/settings/billing/?checkout=success",
cancel_url: "http://localhost:3001/settings/billing/?checkout=cancel",
});
if (session.url) return HttpResponse.json({ redirect_url: session.url });
}
return HttpResponse.json({ message: "Invalid request" }, { status: 400 });
http.post("/api/billing/create-customer-setup-session", async () => {
await delay();
return HttpResponse.json({
redirect_url: "https://stripe.com/some-customer-setup",
});
}),
];

View File

@@ -22,12 +22,13 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = {
enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER,
enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS,
user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS,
provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS,
};
const MOCK_USER_PREFERENCES: {
settings: ApiSettings | PostApiSettings;
settings: ApiSettings | PostApiSettings | null;
} = {
settings: MOCK_DEFAULT_USER_SETTINGS,
settings: null,
};
const conversations: Conversation[] = [
@@ -174,22 +175,24 @@ export const handlers = [
),
http.get("/api/options/config", () => {
const mockSaas = import.meta.env.VITE_MOCK_SAAS === "true";
const config: GetConfigResponse = {
APP_MODE: mockSaas ? "saas" : "oss",
GITHUB_CLIENT_ID: "fake-github-client-id",
POSTHOG_CLIENT_KEY: "fake-posthog-client-key",
STRIPE_PUBLISHABLE_KEY: "",
};
return HttpResponse.json(config);
}),
http.get("/api/settings", async () => {
await delay();
const settings: ApiSettings = {
...MOCK_USER_PREFERENCES.settings,
language: "no",
};
// @ts-expect-error - mock types
if (settings.github_token) settings.github_token_is_set = true;
const { settings } = MOCK_USER_PREFERENCES;
if (!settings) return HttpResponse.json(null, { status: 404 });
if (Object.keys(settings.provider_tokens).length > 0)
settings.github_token_is_set = true;
return HttpResponse.json(settings);
}),
@@ -201,17 +204,19 @@ export const handlers = [
if (typeof body === "object") {
newSettings = { ...body };
if (newSettings.unset_github_token) {
newSettings.github_token = undefined;
newSettings.provider_tokens = { github: "", gitlab: "" };
newSettings.github_token_is_set = false;
delete newSettings.unset_github_token;
}
}
MOCK_USER_PREFERENCES.settings = {
const fullSettings = {
...MOCK_DEFAULT_USER_SETTINGS,
...MOCK_USER_PREFERENCES.settings,
...newSettings,
};
MOCK_USER_PREFERENCES.settings = fullSettings;
return HttpResponse.json(null, { status: 200 });
}

View File

@@ -24,7 +24,10 @@ function Home() {
});
return (
<div className="bg-base-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto px-2">
<div
data-testid="home-screen"
className="bg-base-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto px-2"
>
<HeroHeading />
<div className="flex flex-col gap-8 w-full md:w-[600px] items-center">
<div className="flex flex-col gap-2 w-full">

View File

@@ -1,5 +1,13 @@
import React from "react";
import { useRouteError, isRouteErrorResponse, Outlet } from "react-router";
import {
useRouteError,
isRouteErrorResponse,
Outlet,
useNavigate,
useLocation,
useSearchParams,
} from "react-router";
import { useTranslation } from "react-i18next";
import i18n from "#/i18n";
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
import { useIsAuthed } from "#/hooks/query/use-is-authed";
@@ -10,6 +18,10 @@ import { AnalyticsConsentFormModal } from "#/components/features/analytics/analy
import { useSettings } from "#/hooks/query/use-settings";
import { useAuth } from "#/context/auth-context";
import { useMigrateUserConsent } from "#/hooks/use-migrate-user-consent";
import { useBalance } from "#/hooks/query/use-balance";
import { SetupPaymentModal } from "#/components/features/payment/setup-payment-modal";
import { BILLING_SETTINGS } from "#/utils/feature-flags";
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
export function ErrorBoundary() {
const error = useRouteError();
@@ -44,11 +56,14 @@ export function ErrorBoundary() {
}
export default function MainApp() {
const navigate = useNavigate();
const { pathname } = useLocation();
const [searchParams] = useSearchParams();
const { githubTokenIsSet } = useAuth();
const { data: settings } = useSettings();
const { error, isFetching } = useBalance();
const { migrateUserConsent } = useMigrateUserConsent();
const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false);
const { t } = useTranslation();
const config = useConfig();
const {
@@ -62,6 +77,8 @@ export default function MainApp() {
gitHubClientId: config.data?.GITHUB_CLIENT_ID || null,
});
const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false);
React.useEffect(() => {
if (settings?.LANGUAGE) {
i18n.changeLanguage(settings.LANGUAGE);
@@ -84,6 +101,17 @@ export default function MainApp() {
});
}, []);
React.useEffect(() => {
// Don't allow users to use the app if it 402s
if (error?.status === 402 && pathname !== "/") {
navigate("/");
} else if (!isFetching && searchParams.get("free_credits") === "success") {
displaySuccessToast(t("BILLING$YOURE_IN"));
searchParams.delete("free_credits");
navigate("/");
}
}, [error?.status, pathname, isFetching]);
const userIsAuthed = !!isAuthed && !authError;
const renderWaitlistModal =
!isFetchingAuth && !userIsAuthed && config.data?.APP_MODE === "saas";
@@ -116,6 +144,10 @@ export default function MainApp() {
}}
/>
)}
{BILLING_SETTINGS() &&
config.data?.APP_MODE === "saas" &&
settings?.IS_NEW_USER && <SetupPaymentModal />}
</div>
);
}

View File

@@ -61,7 +61,10 @@ function AccountSettings() {
if (isSuccess) {
return (
isCustomModel(resources.models, settings.LLM_MODEL) ||
hasAdvancedSettingsSet(settings)
hasAdvancedSettingsSet({
...settings,
PROVIDER_TOKENS: settings.PROVIDER_TOKENS || {},
})
);
}
@@ -128,37 +131,42 @@ function AccountSettings() {
: llmBaseUrl;
const finalLlmApiKey = shouldHandleSpecialSaasCase ? undefined : llmApiKey;
saveSettings(
{
github_token:
formData.get("github-token-input")?.toString() || undefined,
LANGUAGE: languageValue,
user_consents_to_analytics: userConsentsToAnalytics,
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
LLM_MODEL: finalLlmModel,
LLM_BASE_URL: finalLlmBaseUrl,
LLM_API_KEY: finalLlmApiKey,
AGENT: formData.get("agent-input")?.toString(),
SECURITY_ANALYZER:
formData.get("security-analyzer-input")?.toString() || "",
REMOTE_RUNTIME_RESOURCE_FACTOR:
remoteRuntimeResourceFactor ||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
CONFIRMATION_MODE: confirmationModeIsEnabled,
const githubToken = formData.get("github-token-input")?.toString();
const newSettings = {
github_token: githubToken,
provider_tokens: githubToken
? {
github: githubToken,
gitlab: "",
}
: undefined,
LANGUAGE: languageValue,
user_consents_to_analytics: userConsentsToAnalytics,
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
LLM_MODEL: finalLlmModel,
LLM_BASE_URL: finalLlmBaseUrl,
LLM_API_KEY: finalLlmApiKey,
AGENT: formData.get("agent-input")?.toString(),
SECURITY_ANALYZER:
formData.get("security-analyzer-input")?.toString() || "",
REMOTE_RUNTIME_RESOURCE_FACTOR:
remoteRuntimeResourceFactor ||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
CONFIRMATION_MODE: confirmationModeIsEnabled,
};
saveSettings(newSettings, {
onSuccess: () => {
handleCaptureConsent(userConsentsToAnalytics);
displaySuccessToast("Settings saved");
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
},
{
onSuccess: () => {
handleCaptureConsent(userConsentsToAnalytics);
displaySuccessToast("Settings saved");
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
},
onError: (error) => {
const errorMessage = retrieveAxiosErrorMessage(error);
displayErrorToast(errorMessage);
},
onError: (error) => {
const errorMessage = retrieveAxiosErrorMessage(error);
displayErrorToast(errorMessage);
},
);
});
};
const handleReset = () => {

View File

@@ -62,13 +62,12 @@ const messageActions = {
let successPrediction = "";
if (message.args.task_completed === "partial") {
successPrediction =
"The agent thinks that the task was **completed partially**.";
"I believe that the task was **completed partially**.";
} else if (message.args.task_completed === "false") {
successPrediction =
"The agent thinks that the task was **not completed**.";
successPrediction = "I believe that the task was **not completed**.";
} else if (message.args.task_completed === "true") {
successPrediction =
"The agent thinks that the task was **completed successfully**.";
"I believe that the task was **completed successfully**.";
}
if (successPrediction) {
// if final_thought is not empty, add a new line before the success prediction

View File

@@ -15,6 +15,11 @@ export const DEFAULT_SETTINGS: Settings = {
ENABLE_DEFAULT_CONDENSER: true,
ENABLE_SOUND_NOTIFICATIONS: false,
USER_CONSENTS_TO_ANALYTICS: false,
PROVIDER_TOKENS: {
github: "",
gitlab: "",
},
IS_NEW_USER: true,
};
/**

View File

@@ -21,7 +21,6 @@ export interface InitConfig {
LLM_MODEL: string;
};
token?: string;
github_token?: string;
latest_event_id?: unknown; // Not sure what this is
}

View File

@@ -1,3 +1,5 @@
export type Provider = "github" | "gitlab";
export type Settings = {
LLM_MODEL: string;
LLM_BASE_URL: string;
@@ -11,6 +13,8 @@ export type Settings = {
ENABLE_DEFAULT_CONDENSER: boolean;
ENABLE_SOUND_NOTIFICATIONS: boolean;
USER_CONSENTS_TO_ANALYTICS: boolean | null;
PROVIDER_TOKENS: Record<Provider, string>;
IS_NEW_USER?: boolean;
};
export type ApiSettings = {
@@ -26,16 +30,17 @@ export type ApiSettings = {
enable_default_condenser: boolean;
enable_sound_notifications: boolean;
user_consents_to_analytics: boolean | null;
provider_tokens: Record<Provider, string>;
};
export type PostSettings = Settings & {
github_token: string;
provider_tokens: Record<Provider, string>;
unset_github_token: boolean;
user_consents_to_analytics: boolean | null;
};
export type PostApiSettings = ApiSettings & {
github_token: string;
provider_tokens: Record<Provider, string>;
unset_github_token: boolean;
user_consents_to_analytics: boolean | null;
};

View File

@@ -59,6 +59,18 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
ENABLE_DEFAULT_CONDENSER,
} = extractAdvancedFormData(formData);
// Extract provider tokens
const githubToken = formData.get("github-token")?.toString();
const gitlabToken = formData.get("gitlab-token")?.toString();
const providerTokens: Record<string, string> = {};
if (githubToken) {
providerTokens.github = githubToken;
}
if (gitlabToken) {
providerTokens.gitlab = gitlabToken;
}
return {
LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL,
LLM_API_KEY,
@@ -68,5 +80,6 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
CONFIRMATION_MODE,
SECURITY_ANALYZER,
ENABLE_DEFAULT_CONDENSER,
PROVIDER_TOKENS: providerTokens,
};
};

View File

@@ -17,45 +17,7 @@ export default {
tertiary: "#454545", // gray, used for inputs
"tertiary-light": "#B7BDC2", // lighter gray, used for borders and placeholder text
content: "#ECEDEE", // light gray, used mostly for text
},
},
animation: {
enter: "toastIn 400ms cubic-bezier(0.21, 1.02, 0.73, 1)",
leave: "toastOut 100ms ease-in forwards",
},
keyframes: {
toastIn: {
"0%": {
opacity: "0",
transform: "translateY(-100%) scale(0.8)",
},
"80%": {
opacity: "1",
transform: "translateY(0) scale(1.02)",
},
"100%": {
opacity: "1",
transform: "translateY(0) scale(1)",
},
},
toastOut: {
"0%": {
opacity: "1",
transform: "translateY(0) scale(1)",
},
"100%": {
opacity: "0",
transform: "translateY(-100%) scale(0.9)",
},
},
colors: {
primary: "#C9B974", // nice yellow
base: "#171717", // dark background (neutral-900)
"base-secondary": "#262626", // lighter background (neutral-800); also used for tooltips
danger: "#E76A5E",
success: "#A5E75E",
tertiary: "#454545", // gray, used for inputs
"tertiary-light": "#B7BDC2", // lighter gray, used for borders and placeholder text
"content-2": "#F9FBFE",
},
},
},

View File

@@ -6,7 +6,6 @@ load_dotenv()
from openhands.agenthub import ( # noqa: E402
browsing_agent,
codeact_agent,
delegator_agent,
dummy_agent,
visualbrowsing_agent,
)
@@ -15,7 +14,6 @@ from openhands.controller.agent import Agent # noqa: E402
__all__ = [
'Agent',
'codeact_agent',
'delegator_agent',
'dummy_agent',
'browsing_agent',
'visualbrowsing_agent',

View File

@@ -1,8 +1,6 @@
import json
import os
from collections import deque
import openhands
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
@@ -72,23 +70,17 @@ class CodeActAgent(Agent):
codeact_enable_browsing=self.config.codeact_enable_browsing,
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
llm=self.llm,
)
logger.debug(
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
)
self.prompt_manager = PromptManager(
microagent_dir=os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
if self.config.enable_prompt_extensions
else None,
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
disabled_microagents=self.config.disabled_microagents,
)
# Create a ConversationMemory instance
self.conversation_memory = ConversationMemory(self.prompt_manager)
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {type(self.condenser)}')
@@ -168,7 +160,7 @@ class CodeActAgent(Agent):
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
# Use conversation_memory to process events instead of calling events_to_messages directly
# Use ConversationMemory to process initial messages
messages = self.conversation_memory.process_initial_messages(
with_caching=self.llm.is_caching_prompt_active()
)
@@ -180,12 +172,12 @@ class CodeActAgent(Agent):
f'Processing {len(events)} events from a total of {len(state.history)} events'
)
# Use ConversationMemory to process events
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_messages=messages,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
enable_som_visual_browsing=self.config.enable_som_visual_browsing,
)
messages = self._enhance_messages(messages)
@@ -216,14 +208,7 @@ class CodeActAgent(Agent):
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)
# and/or repo/runtime info
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)
# enhance the user message with additional context based on keywords matched
if msg.role == 'user':
self.prompt_manager.enhance_message(msg)
elif msg.role == 'user':
# Add double newline between consecutive user messages
if prev_role == 'user' and len(msg.content) > 0:
# Find the first TextContent in the message to add newlines

View File

@@ -12,13 +12,13 @@ from litellm import (
from openhands.agenthub.codeact_agent.tools import (
BrowserTool,
CmdRunTool,
FinishTool,
IPythonTool,
LLMBasedFileEditTool,
StrReplaceEditorTool,
ThinkTool,
WebReadTool,
create_cmd_run_tool,
create_str_replace_editor_tool,
)
from openhands.core.exceptions import (
FunctionCallNotExistsError,
@@ -39,6 +39,7 @@ from openhands.events.action import (
)
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.tool import ToolCallMetadata
from openhands.llm import LLM
def combine_thought(action: Action, thought: str) -> Action:
@@ -80,7 +81,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
# CmdRunTool (Bash)
# ================================================
if tool_call.function.name == CmdRunTool['function']['name']:
if tool_call.function.name == create_cmd_run_tool()['function']['name']:
if 'command' not in arguments:
raise FunctionCallValidationError(
f'Missing required argument "command" in tool call {tool_call.function.name}'
@@ -131,7 +132,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
start=arguments.get('start', 1),
end=arguments.get('end', -1),
)
elif tool_call.function.name == StrReplaceEditorTool['function']['name']:
elif (
tool_call.function.name
== create_str_replace_editor_tool()['function']['name']
):
if 'command' not in arguments:
raise FunctionCallValidationError(
f'Missing required argument "command" in tool call {tool_call.function.name}'
@@ -219,8 +223,22 @@ def get_tools(
codeact_enable_browsing: bool = False,
codeact_enable_llm_editor: bool = False,
codeact_enable_jupyter: bool = False,
llm: LLM | None = None,
) -> list[ChatCompletionToolParam]:
tools = [CmdRunTool, ThinkTool, FinishTool]
SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1']
use_simplified_tool_desc = False
if llm is not None:
use_simplified_tool_desc = any(
model_substr in llm.config.model
for model_substr in SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS
)
tools = [
create_cmd_run_tool(use_simplified_description=use_simplified_tool_desc),
ThinkTool,
FinishTool,
]
if codeact_enable_browsing:
tools.append(WebReadTool)
tools.append(BrowserTool)
@@ -229,5 +247,9 @@ def get_tools(
if codeact_enable_llm_editor:
tools.append(LLMBasedFileEditTool)
else:
tools.append(StrReplaceEditorTool)
tools.append(
create_str_replace_editor_tool(
use_simplified_description=use_simplified_tool_desc
)
)
return tools

View File

@@ -1,6 +1,6 @@
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions -%}
@@ -20,6 +20,8 @@ When starting a web server, use the corresponding ports. You should also
set any options to allow iframes and CORS requests, and allow the server to
be accessed from any host (e.g. 0.0.0.0).
{% endif %}
{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
</RUNTIME_INFORMATION>
{% endif %}

View File

@@ -1,8 +1,8 @@
{% for agent_info in triggered_agents %}
<EXTRA_INFO>
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
It may or may not be relevant to the user's request.
{{ agent_info.agent.content }}
{{ agent_info.content }}
</EXTRA_INFO>
{% endfor %}

View File

@@ -1,19 +1,19 @@
from .bash import CmdRunTool
from .bash import create_cmd_run_tool
from .browser import BrowserTool
from .finish import FinishTool
from .ipython import IPythonTool
from .llm_based_edit import LLMBasedFileEditTool
from .str_replace_editor import StrReplaceEditorTool
from .str_replace_editor import create_str_replace_editor_tool
from .think import ThinkTool
from .web_read import WebReadTool
__all__ = [
'BrowserTool',
'CmdRunTool',
'create_cmd_run_tool',
'FinishTool',
'IPythonTool',
'LLMBasedFileEditTool',
'StrReplaceEditorTool',
'create_str_replace_editor_tool',
'WebReadTool',
'ThinkTool',
]

View File

@@ -1,6 +1,6 @@
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
_DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
### Command Execution
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
@@ -22,25 +22,39 @@ _BASH_DESCRIPTION = """Execute a bash command in the terminal within a persisten
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
"""
CmdRunTool = ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='execute_bash',
description=_BASH_DESCRIPTION,
parameters={
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
},
'is_input': {
'type': 'string',
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
'enum': ['true', 'false'],
_SIMPLIFIED_BASH_DESCRIPTION = """Execute a bash command in the terminal.
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."""
def create_cmd_run_tool(
use_simplified_description: bool = False,
) -> ChatCompletionToolParam:
description = (
_SIMPLIFIED_BASH_DESCRIPTION
if use_simplified_description
else _DETAILED_BASH_DESCRIPTION
)
return ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='execute_bash',
description=description,
parameters={
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
},
'is_input': {
'type': 'string',
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
'enum': ['true', 'false'],
},
},
'required': ['command'],
},
'required': ['command'],
},
),
)
),
)

View File

@@ -1,6 +1,6 @@
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
_DETAILED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
* State is persistent across command calls and discussions with the user
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
* The `create` command cannot be used if the specified `path` already exists as a file
@@ -31,46 +31,73 @@ CRITICAL REQUIREMENTS FOR USING THIS TOOL:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
"""
StrReplaceEditorTool = ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='str_replace_editor',
description=_STR_REPLACE_EDITOR_DESCRIPTION,
parameters={
'type': 'object',
'properties': {
'command': {
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
'type': 'string',
},
'path': {
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
'type': 'string',
},
'file_text': {
'description': 'Required parameter of `create` command, with the content of the file to be created.',
'type': 'string',
},
'old_str': {
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
'type': 'string',
},
'new_str': {
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
'type': 'string',
},
'insert_line': {
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
'type': 'integer',
},
'view_range': {
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
'items': {'type': 'integer'},
'type': 'array',
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
* State is persistent across command calls and discussions with the user
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
* The `create` command cannot be used if the specified `path` already exists as a file
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
* The `undo_edit` command will revert the last edit made to the file at `path`
Notes for using the `str_replace` command:
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
"""
def create_str_replace_editor_tool(
use_simplified_description: bool = False,
) -> ChatCompletionToolParam:
description = (
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION
if use_simplified_description
else _DETAILED_STR_REPLACE_EDITOR_DESCRIPTION
)
return ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='str_replace_editor',
description=description,
parameters={
'type': 'object',
'properties': {
'command': {
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
'enum': [
'view',
'create',
'str_replace',
'insert',
'undo_edit',
],
'type': 'string',
},
'path': {
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
'type': 'string',
},
'file_text': {
'description': 'Required parameter of `create` command, with the content of the file to be created.',
'type': 'string',
},
'old_str': {
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
'type': 'string',
},
'new_str': {
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
'type': 'string',
},
'insert_line': {
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
'type': 'integer',
},
'view_range': {
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
'items': {'type': 'integer'},
'type': 'array',
},
},
'required': ['command', 'path'],
},
'required': ['command', 'path'],
},
),
)
),
)

View File

@@ -1,4 +0,0 @@
from openhands.agenthub.delegator_agent.agent import DelegatorAgent
from openhands.controller.agent import Agent
Agent.register('DelegatorAgent', DelegatorAgent)

View File

@@ -1,87 +0,0 @@
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
from openhands.events.observation import AgentDelegateObservation, Observation
from openhands.llm.llm import LLM
class DelegatorAgent(Agent):
VERSION = '1.0'
"""
The Delegator Agent is responsible for delegating tasks to other agents based on the current task.
"""
current_delegate: str = ''
def __init__(self, llm: LLM, config: AgentConfig):
"""Initialize the Delegator Agent with an LLM
Parameters:
- llm (LLM): The llm to be used by this agent
"""
super().__init__(llm, config)
def step(self, state: State) -> Action:
"""Checks to see if current step is completed, returns AgentFinishAction if True.
Otherwise, delegates the task to the next agent in the pipeline.
Parameters:
- state (State): The current state given the previous actions and observations
Returns:
- AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned'
- AgentDelegateAction: The next agent to delegate the task to
"""
if self.current_delegate == '':
self.current_delegate = 'study'
task, _ = state.get_current_user_intent()
return AgentDelegateAction(
agent='StudyRepoForTaskAgent', inputs={'task': task}
)
# last observation in history should be from the delegate
last_observation = None
for event in reversed(state.history):
if isinstance(event, Observation):
last_observation = event
break
if not isinstance(last_observation, AgentDelegateObservation):
raise Exception('Last observation is not an AgentDelegateObservation')
goal, _ = state.get_current_user_intent()
if self.current_delegate == 'study':
self.current_delegate = 'coder'
return AgentDelegateAction(
agent='CoderAgent',
inputs={
'task': goal,
'summary': last_observation.outputs['summary'],
},
)
elif self.current_delegate == 'coder':
self.current_delegate = 'verifier'
return AgentDelegateAction(
agent='VerifierAgent',
inputs={
'task': goal,
},
)
elif self.current_delegate == 'verifier':
if (
'completed' in last_observation.outputs
and last_observation.outputs['completed']
):
return AgentFinishAction()
else:
self.current_delegate = 'coder'
return AgentDelegateAction(
agent='CoderAgent',
inputs={
'task': goal,
'summary': last_observation.outputs['summary'],
},
)
else:
raise Exception('Invalid delegate state')

View File

@@ -202,6 +202,7 @@ Note:
tabs = ''
last_obs = None
last_action = None
set_of_marks = None # Initialize set_of_marks to None
if len(state.history) == 1:
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
@@ -217,6 +218,9 @@ Note:
# agent has responded, task finished.
return AgentFinishAction(outputs={'content': event.content})
elif isinstance(event, Observation):
# Only process BrowserOutputObservation and skip other observation types
if not isinstance(event, BrowserOutputObservation):
continue
last_obs = event
if len(prev_actions) >= 1: # ignore noop()

View File

@@ -29,7 +29,12 @@ from openhands.core.exceptions import (
from openhands.core.logger import LOG_ALL_EVENTS
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events import (
EventSource,
EventStream,
EventStreamSubscriber,
RecallType,
)
from openhands.events.action import (
Action,
ActionConfirmationStatus,
@@ -42,6 +47,7 @@ from openhands.events.action import (
MessageAction,
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
@@ -89,7 +95,7 @@ class AgentController:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
sid: str = 'default',
sid: str | None = None,
confirmation_mode: bool = False,
initial_state: State | None = None,
is_delegate: bool = False,
@@ -116,7 +122,7 @@ class AgentController:
status_callback: Optional callback function to handle status updates.
replay_events: A list of logs to replay.
"""
self.id = sid
self.id = sid or event_stream.sid
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
@@ -287,8 +293,14 @@ class AgentController:
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
event, AgentStateChangedObservation
if (
isinstance(event, NullObservation)
and event.cause is not None
and event.cause > 0
):
return True
if isinstance(event, AgentStateChangedObservation) or isinstance(
event, NullObservation
):
return False
return True
@@ -388,6 +400,7 @@ class AgentController:
if observation.llm_metrics is not None:
self.agent.llm.metrics.merge(observation.llm_metrics)
# this happens for runnable actions and microagent actions
if self._pending_action and self._pending_action.id == observation.cause:
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
return
@@ -431,6 +444,25 @@ class AgentController:
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
)
# try to retrieve microagents relevant to the user message
# set pending_action while we search for information
# if this is the first user message for this agent, matters for the microagent info type
first_user_message = self._first_user_message()
is_first_user_message = (
action.id == first_user_message.id if first_user_message else False
)
recall_type = (
RecallType.WORKSPACE_CONTEXT
if is_first_user_message
else RecallType.KNOWLEDGE
)
recall_action = RecallAction(query=action.content, recall_type=recall_type)
self._pending_action = recall_action
# this is source=USER because the user message is the trigger for the microagent retrieval
self.event_stream.add_event(recall_action, EventSource.USER)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
@@ -438,6 +470,7 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
# Runnable actions need an Observation
# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
@@ -459,6 +492,8 @@ class AgentController:
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()
@@ -1146,3 +1181,26 @@ class AgentController:
result = event.agent_state == AgentState.RUNNING
return result
return False
def _first_user_message(self) -> MessageAction | None:
"""
Get the first user message for this agent.
For regular agents, this is the first user message from the beginning (start_id=0).
For delegate agents, this is the first user message after the delegate's start_id.
Returns:
MessageAction | None: The first user message, or None if no user message found
"""
# Find the first user message from the appropriate starting point
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
# Get and return the first user message
return next(
(
e
for e in user_messages
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)

View File

@@ -103,8 +103,9 @@ class StuckDetector:
return True
# scenario 5: context window error loop
if self._is_stuck_context_window_error(filtered_history):
return True
if len(filtered_history) >= 10:
if self._is_stuck_context_window_error(filtered_history):
return True
return False
@@ -134,7 +135,7 @@ class StuckDetector:
# it takes 3 actions and 3 observations to detect a loop
# check if the last three actions are the same and result in errors
if len(last_actions) < 4 or len(last_observations) < 4:
if len(last_actions) < 3 or len(last_observations) < 3:
return False
# are the last three actions the "same"?
@@ -333,12 +334,12 @@ class StuckDetector:
if isinstance(event, AgentCondensationObservation)
]
# Need at least 3 condensation events to detect a loop
if len(condensation_events) < 3:
# Need at least 10 condensation events to detect a loop
if len(condensation_events) < 10:
return False
# Get the last 3 condensation events
last_condensation_events = condensation_events[-3:]
# Get the last 10 condensation events
last_condensation_events = condensation_events[-10:]
# Check if there are any non-condensation events between them
for i in range(len(last_condensation_events) - 1):

View File

@@ -17,6 +17,7 @@ from openhands.core.schema import AgentState
from openhands.core.setup import (
create_agent,
create_controller,
create_memory,
create_runtime,
initialize_repository_for_runtime,
)
@@ -170,13 +171,22 @@ async def main(loop: asyncio.AbstractEventLoop):
await runtime.connect()
# Initialize repository if needed
repo_directory = None
if config.sandbox.selected_repo:
initialize_repository_for_runtime(
repo_directory = initialize_repository_for_runtime(
runtime,
agent=agent,
selected_repository=config.sandbox.selected_repo,
)
# when memory is created, it will load the microagents from the selected repository
memory = create_memory(
runtime=runtime,
event_stream=event_stream,
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
)
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
@@ -185,7 +195,7 @@ async def main(loop: asyncio.AbstractEventLoop):
asyncio.create_task(prompt_for_next_task())
await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
)

View File

@@ -15,6 +15,7 @@ class SandboxConfig(BaseModel):
timeout: The timeout for the default sandbox action execution.
remote_runtime_init_timeout: The timeout for the remote runtime to start.
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
remote_runtime_enable_retries: Whether to enable retries (on recoverable errors like requests.ConnectionError) for the remote runtime API requests.
enable_auto_lint: Whether to enable auto-lint.
use_host_network: Whether to use the host network.
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
@@ -53,7 +54,7 @@ class SandboxConfig(BaseModel):
timeout: int = Field(default=120)
remote_runtime_init_timeout: int = Field(default=180)
remote_runtime_api_timeout: int = Field(default=10)
remote_runtime_enable_retries: bool = Field(default=False)
remote_runtime_enable_retries: bool = Field(default=True)
remote_runtime_class: str | None = Field(
default=None
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)

View File

@@ -240,7 +240,7 @@ class SensitiveDataFilter(logging.Filter):
if (
len(value) > 2
and value != 'default'
and any(s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN'))
and any(s in key_upper for s in ('SECRET', '_KEY', '_CODE', '_TOKEN'))
):
sensitive_values.append(value)

View File

@@ -3,12 +3,14 @@ import asyncio
from openhands.controller import AgentController
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
async def run_agent_until_done(
controller: AgentController,
runtime: Runtime,
memory: Memory,
end_states: list[AgentState],
):
"""
@@ -37,6 +39,7 @@ async def run_agent_until_done(
runtime.status_callback = status_callback
controller.status_callback = status_callback
memory.status_callback = status_callback
while controller.state.agent_state not in end_states:
await asyncio.sleep(1)

View File

@@ -18,6 +18,7 @@ from openhands.core.schema import AgentState
from openhands.core.setup import (
create_agent,
create_controller,
create_memory,
create_runtime,
generate_sid,
initialize_repository_for_runtime,
@@ -29,6 +30,7 @@ from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.io import read_input, read_task
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
@@ -51,6 +53,7 @@ async def run_controller(
exit_on_message: bool = False,
fake_user_response_fn: FakeUserResponseFunc | None = None,
headless_mode: bool = True,
memory: Memory | None = None,
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
@@ -93,6 +96,8 @@ async def run_controller(
if agent is None:
agent = create_agent(config)
# when the runtime is created, it will be connected and clone the selected repository
repo_directory = None
if runtime is None:
runtime = create_runtime(
config,
@@ -105,14 +110,23 @@ async def run_controller(
# Initialize repository if needed
if config.sandbox.selected_repo:
initialize_repository_for_runtime(
repo_directory = initialize_repository_for_runtime(
runtime,
agent=agent,
selected_repository=config.sandbox.selected_repo,
)
event_stream = runtime.event_stream
# when memory is created, it will load the microagents from the selected repository
if memory is None:
memory = create_memory(
runtime=runtime,
event_stream=event_stream,
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
)
replay_events: list[Event] | None = None
if config.replay_trajectory_path:
logger.info('Trajectory replay is enabled')
@@ -172,7 +186,7 @@ async def run_controller(
]
try:
await run_agent_until_done(controller, runtime, end_states)
await run_agent_until_done(controller, runtime, memory, end_states)
except Exception as e:
logger.error(f'Exception in main loop: {e}')

View File

@@ -82,5 +82,8 @@ class ActionTypeSchema(BaseModel):
SEND_PR: str = Field(default='send_pr')
"""Send a PR to github."""
RECALL: str = Field(default='recall')
"""Retrieves content from a user workspace, microagent, or other source."""
ActionType = ActionTypeSchema()

View File

@@ -49,5 +49,8 @@ class ObservationTypeSchema(BaseModel):
CONDENSE: str = Field(default='condense')
"""Result of a condensation operation."""
RECALL: str = Field(default='recall')
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
ObservationType = ObservationTypeSchema()

View File

@@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Tuple, Type
from typing import Callable, Tuple, Type
from pydantic import SecretStr
@@ -16,6 +16,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
@@ -83,7 +84,6 @@ def create_runtime(
def initialize_repository_for_runtime(
runtime: Runtime,
agent: Agent | None = None,
selected_repository: str | None = None,
github_token: SecretStr | None = None,
) -> str | None:
@@ -91,7 +91,6 @@ def initialize_repository_for_runtime(
Args:
runtime: The runtime to initialize the repository for.
agent: (optional) The agent to load microagents for.
selected_repository: (optional) The GitHub repository to use.
github_token: (optional) The GitHub token to use.
@@ -99,10 +98,10 @@ def initialize_repository_for_runtime(
The repository directory path if a repository was cloned, None otherwise.
"""
# clone selected repository if provided
repo_directory = None
github_token = (
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
)
repo_directory = None
if selected_repository and github_token:
logger.debug(f'Selected repository {selected_repository}.')
repo_directory = runtime.clone_repo(
@@ -111,16 +110,47 @@ def initialize_repository_for_runtime(
None,
)
# load microagents from selected repository
if agent and agent.prompt_manager and selected_repository and repo_directory:
agent.prompt_manager.set_runtime_info(runtime)
return repo_directory
def create_memory(
runtime: Runtime,
event_stream: EventStream,
sid: str,
selected_repository: str | None = None,
repo_directory: str | None = None,
status_callback: Callable | None = None,
) -> Memory:
"""Create a memory for the agent to use.
Args:
runtime: The runtime to use.
event_stream: The event stream it will subscribe to.
sid: The session id.
selected_repository: The repository to clone and start with, if any.
repo_directory: The repository directory, if any.
status_callback: Optional callback function to handle status updates.
"""
memory = Memory(
event_stream=event_stream,
sid=sid,
status_callback=status_callback,
)
if runtime:
# sets available hosts
memory.set_runtime_info(runtime)
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
selected_repository
)
agent.prompt_manager.load_microagents(microagents)
agent.prompt_manager.set_repository_info(selected_repository, repo_directory)
memory.load_user_workspace_microagents(microagents)
return repo_directory
if selected_repository and repo_directory:
memory.set_repository_info(selected_repository, repo_directory)
return memory
def create_agent(config: AppConfig) -> Agent:

View File

@@ -1,4 +1,4 @@
from openhands.events.event import Event, EventSource
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
__all__ = [
@@ -6,4 +6,5 @@ __all__ = [
'EventSource',
'EventStream',
'EventStreamSubscriber',
'RecallType',
]

View File

@@ -6,6 +6,7 @@ from openhands.events.action.agent import (
AgentSummarizeAction,
AgentThinkAction,
ChangeAgentStateAction,
RecallAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
@@ -35,4 +36,5 @@ __all__ = [
'MessageAction',
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',
]

View File

@@ -4,6 +4,7 @@ from typing import Any
from openhands.core.schema import ActionType
from openhands.events.action.action import Action
from openhands.events.event import RecallType
@dataclass
@@ -106,3 +107,22 @@ class AgentDelegateAction(Action):
@property
def message(self) -> str:
return f"I'm asking {self.agent} for help with this task."
@dataclass
class RecallAction(Action):
"""This action is used for retrieving content, e.g., from the global directory or user workspace."""
recall_type: RecallType
query: str = ''
thought: str = ''
action: str = ActionType.RECALL
@property
def message(self) -> str:
return f'Retrieving content for: {self.query[:50]}'
def __str__(self) -> str:
ret = '**RecallAction**\n'
ret += f'QUERY: {self.query[:50]}'
return ret

View File

@@ -22,6 +22,16 @@ class FileReadSource(str, Enum):
DEFAULT = 'default'
class RecallType(str, Enum):
"""The type of information that can be retrieved from microagents."""
WORKSPACE_CONTEXT = 'workspace_context'
"""Workspace context (repo instructions, runtime, etc.)"""
KNOWLEDGE = 'knowledge'
"""A knowledge microagent."""
@dataclass
class Event:
INVALID_ID = -1

View File

@@ -1,7 +1,9 @@
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
AgentThinkObservation,
RecallObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@@ -40,4 +42,6 @@ __all__ = [
'SuccessObservation',
'UserRejectObservation',
'AgentCondensationObservation',
'RecallObservation',
'RecallType',
]

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from openhands.core.schema import ObservationType
from openhands.events.event import RecallType
from openhands.events.observation.observation import Observation
@@ -40,3 +41,90 @@ class AgentThinkObservation(Observation):
@property
def message(self) -> str:
return self.content
@dataclass
class MicroagentKnowledge:
"""
Represents knowledge from a triggered microagent.
Attributes:
name: The name of the microagent that was triggered
trigger: The word that triggered this microagent
content: The actual content/knowledge from the microagent
"""
name: str
trigger: str
content: str
@dataclass
class RecallObservation(Observation):
"""The retrieval of content from a microagent or more microagents."""
recall_type: RecallType
observation: str = ObservationType.RECALL
# workspace context
repo_name: str = ''
repo_directory: str = ''
repo_instructions: str = ''
runtime_hosts: dict[str, int] = field(default_factory=dict)
additional_agent_instructions: str = ''
# knowledge
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
"""
A list of MicroagentKnowledge objects, each containing information from a triggered microagent.
Example:
[
MicroagentKnowledge(
name="python_best_practices",
trigger="python",
content="Always use virtual environments for Python projects."
),
MicroagentKnowledge(
name="git_workflow",
trigger="git",
content="Create a new branch for each feature or bugfix."
)
]
"""
@property
def message(self) -> str:
return (
'Added workspace context'
if self.recall_type == RecallType.WORKSPACE_CONTEXT
else 'Added microagent knowledge'
)
def __str__(self) -> str:
# Build a string representation
fields = []
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
fields.extend(
[
f'recall_type={self.recall_type}',
f'repo_name={self.repo_name}',
f'repo_instructions={self.repo_instructions[:20]}...',
f'runtime_hosts={self.runtime_hosts}',
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
]
)
else:
fields.extend(
[
f'recall_type={self.recall_type}',
]
)
if self.microagent_knowledge:
fields.extend(
[
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
]
)
return f'**RecallObservation**\n{", ".join(fields)}'

View File

@@ -8,6 +8,7 @@ from openhands.events.action.agent import (
AgentRejectAction,
AgentThinkAction,
ChangeAgentStateAction,
RecallAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import (
@@ -35,6 +36,7 @@ actions = (
AgentFinishAction,
AgentRejectAction,
AgentDelegateAction,
RecallAction,
ChangeAgentStateAction,
MessageAction,
)

View File

@@ -1,5 +1,6 @@
from dataclasses import asdict
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
@@ -102,6 +103,8 @@ def event_to_dict(event: 'Event') -> dict:
d['timestamp'] = d['timestamp'].isoformat()
if key == 'source' and 'source' in d:
d['source'] = d['source'].value
if key == 'recall_type' and 'recall_type' in d:
d['recall_type'] = d['recall_type'].value
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
if key == 'llm_metrics' and 'llm_metrics' in d:
@@ -119,7 +122,11 @@ def event_to_dict(event: 'Event') -> dict:
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
# such as CmdOutputMetadata
# we serialize it along with the rest
d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()}
# we also handle the Enum conversion for RecallObservation
d['extras'] = {
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
for k, v in props.items()
}
# Include success field for CmdOutputObservation
if hasattr(event, 'success'):
d['success'] = event.success

View File

@@ -1,9 +1,12 @@
import copy
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
AgentThinkObservation,
MicroagentKnowledge,
RecallObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@@ -40,6 +43,7 @@ observations = (
UserRejectObservation,
AgentCondensationObservation,
AgentThinkObservation,
RecallObservation,
)
OBSERVATION_TYPE_TO_CLASS = {
@@ -110,4 +114,18 @@ def observation_from_dict(observation: dict) -> Observation:
else:
extras['metadata'] = CmdOutputMetadata()
if observation_class is RecallObservation:
# handle the Enum conversion
if 'recall_type' in extras:
extras['recall_type'] = RecallType(extras['recall_type'])
# convert dicts in microagent_knowledge to MicroagentKnowledge objects
if 'microagent_knowledge' in extras and isinstance(
extras['microagent_knowledge'], list
):
extras['microagent_knowledge'] = [
MicroagentKnowledge(**item) if isinstance(item, dict) else item
for item in extras['microagent_knowledge']
]
return observation_class(content=content, **extras)

View File

@@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum):
RESOLVER = 'openhands_resolver'
SERVER = 'server'
RUNTIME = 'runtime'
MEMORY = 'memory'
MAIN = 'main'
TEST = 'test'

View File

@@ -6,42 +6,44 @@ import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
SuggestedTask,
TaskType,
UnknownException,
User,
)
from openhands.utils.import_utils import get_impl
class GitHubService:
class GitHubService(GitService):
BASE_URL = 'https://api.github.com'
github_token: SecretStr = SecretStr('')
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
user_id: str | None = None,
external_auth_id: str | None = None,
external_auth_token: SecretStr | None = None,
github_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,
):
self.user_id = user_id
self.external_token_manager = external_token_manager
if github_token:
self.github_token = github_token
if token:
self.token = token
async def _get_github_headers(self) -> dict:
"""Retrieve the GH Token from settings store to construct the headers."""
if self.user_id and not self.github_token:
self.github_token = await self.get_latest_token()
if self.user_id and not self.token:
self.token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
'Accept': 'application/vnd.github.v3+json',
}
@@ -49,7 +51,7 @@ class GitHubService:
return status_code == 401
async def get_latest_token(self) -> SecretStr | None:
return self.github_token
return self.token
async def _fetch_data(
self, url: str, params: dict | None = None
@@ -74,20 +76,20 @@ class GitHubService:
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise AuthenticationError('Invalid Github token')
logger.warning(f'Status error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
except httpx.HTTPError as e:
logger.warning(f'HTTP error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
async def get_user(self) -> GitHubUser:
async def get_user(self) -> User:
url = f'{self.BASE_URL}/user'
response, _ = await self._fetch_data(url)
return GitHubUser(
return User(
id=response.get('id'),
login=response.get('login'),
avatar_url=response.get('avatar_url'),
@@ -98,7 +100,7 @@ class GitHubService:
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[GitHubRepository]:
) -> list[Repository]:
params = {'page': str(page), 'per_page': str(per_page)}
if installation_id:
url = f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
@@ -111,7 +113,7 @@ class GitHubService:
next_link: str = headers.get('Link', '')
repos = [
GitHubRepository(
Repository(
id=repo.get('id'),
full_name=repo.get('full_name'),
stargazers_count=repo.get('stargazers_count'),
@@ -129,7 +131,7 @@ class GitHubService:
async def search_repositories(
self, query: str, per_page: int, sort: str, order: str
) -> list[GitHubRepository]:
) -> list[Repository]:
url = f'{self.BASE_URL}/search/repositories'
params = {'q': query, 'per_page': per_page, 'sort': sort, 'order': order}
@@ -137,7 +139,7 @@ class GitHubService:
repos = response.get('items', [])
repos = [
GitHubRepository(
Repository(
id=repo.get('id'),
full_name=repo.get('full_name'),
stargazers_count=repo.get('stargazers_count'),
@@ -163,7 +165,7 @@ class GitHubService:
result = response.json()
if 'errors' in result:
raise GHUnknownException(
raise UnknownException(
f"GraphQL query error: {json.dumps(result['errors'])}"
)
@@ -171,14 +173,14 @@ class GitHubService:
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise AuthenticationError('Invalid Github token')
logger.warning(f'Status error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
except httpx.HTTPError as e:
logger.warning(f'HTTP error on GH API: {e}')
raise GHUnknownException('Unknown error')
raise UnknownException('Unknown error')
async def get_suggested_tasks(self) -> list[SuggestedTask]:
"""Get suggested tasks for the authenticated user across all repositories.

View File

@@ -1,46 +0,0 @@
from enum import Enum
from pydantic import BaseModel
class TaskType(str, Enum):
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
FAILING_CHECKS = 'FAILING_CHECKS'
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
OPEN_ISSUE = 'OPEN_ISSUE'
OPEN_PR = 'OPEN_PR'
class SuggestedTask(BaseModel):
task_type: TaskType
repo: str
issue_number: int
title: str
class GitHubUser(BaseModel):
id: int
login: str
avatar_url: str
company: str | None = None
name: str | None = None
email: str | None = None
class GitHubRepository(BaseModel):
id: int
full_name: str
stargazers_count: int | None = None
link_header: str | None = None
class GhAuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""
pass
class GHUnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""
pass

View File

@@ -0,0 +1,119 @@
import os
from typing import Any
import httpx
from pydantic import SecretStr
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
UnknownException,
User,
)
from openhands.utils.import_utils import get_impl
class GitLabService(GitService):
BASE_URL = 'https://gitlab.com/api/v4'
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
user_id: str | None = None,
external_auth_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,
):
self.user_id = user_id
self.external_token_manager = external_token_manager
if token:
self.token = token
async def _get_gitlab_headers(self) -> dict:
"""
Retrieve the GitLab Token to construct the headers
"""
if self.user_id and not self.token:
self.token = await self.get_latest_token()
return {
'Authorization': f'Bearer {self.token.get_secret_value()}',
}
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401
async def get_latest_token(self) -> SecretStr:
return self.token
async def _fetch_data(
self, url: str, params: dict | None = None
) -> tuple[Any, dict]:
try:
async with httpx.AsyncClient() as client:
gitlab_headers = await self._get_gitlab_headers()
response = await client.get(url, headers=gitlab_headers, params=params)
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
gitlab_headers = await self._get_gitlab_headers()
response = await client.get(
url, headers=gitlab_headers, params=params
)
response.raise_for_status()
headers = {}
if 'Link' in response.headers:
headers['Link'] = response.headers['Link']
return response.json(), headers
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError('Invalid GitLab token')
raise UnknownException('Unknown error')
except httpx.HTTPError:
raise UnknownException('Unknown error')
async def get_user(self) -> User:
url = f'{self.BASE_URL}/user'
response, _ = await self._fetch_data(url)
return User(
id=response.get('id'),
username=response.get('username'),
avatar_url=response.get('avatar_url'),
name=response.get('name'),
email=response.get('email'),
company=response.get('organization'),
login=response.get('username'),
)
async def search_repositories(
self, query: str, per_page: int = 30, sort: str = 'updated', order: str = 'desc'
):
url = f'{self.BASE_URL}/search'
params = {
'scope': 'projects',
'search': query,
'per_page': per_page,
'order_by': sort,
'sort': order,
}
response, headers = await self._fetch_data(url, params)
return response, headers
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[Repository]:
return []
gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',
)
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)

View File

@@ -0,0 +1,143 @@
from enum import Enum
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from pydantic.json import pydantic_encoder
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
Repository,
User,
)
class ProviderType(Enum):
GITHUB = 'github'
GITLAB = 'gitlab'
class ProviderToken(BaseModel):
token: SecretStr | None
user_id: str | None
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
class SecretStore(BaseModel):
provider_tokens: PROVIDER_TOKEN_TYPE = {}
@classmethod
def _convert_token(
cls, token_value: str | ProviderToken | SecretStr
) -> ProviderToken:
if isinstance(token_value, ProviderToken):
return token_value
elif isinstance(token_value, str):
return ProviderToken(token=SecretStr(token_value), user_id=None)
elif isinstance(token_value, SecretStr):
return ProviderToken(token=token_value, user_id=None)
else:
raise ValueError(f'Invalid token type: {type(token_value)}')
def model_post_init(self, __context) -> None:
# Convert any string tokens to ProviderToken objects
converted_tokens = {}
for token_type, token_value in self.provider_tokens.items():
if token_value: # Only convert non-empty tokens
try:
if isinstance(token_type, str):
token_type = ProviderType(token_type)
converted_tokens[token_type] = self._convert_token(token_value)
except ValueError:
# Skip invalid provider types or tokens
continue
self.provider_tokens = converted_tokens
@field_serializer('provider_tokens')
def provider_tokens_serializer(
self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
):
tokens = {}
expose_secrets = info.context and info.context.get('expose_secrets', False)
for token_type, provider_token in provider_tokens.items():
if not provider_token or not provider_token.token:
continue
token_type_str = (
token_type.value
if isinstance(token_type, ProviderType)
else str(token_type)
)
tokens[token_type_str] = {
'token': provider_token.token.get_secret_value()
if expose_secrets
else pydantic_encoder(provider_token.token),
'user_id': provider_token.user_id,
}
return tokens
class ProviderHandler:
def __init__(
self,
provider_tokens: PROVIDER_TOKEN_TYPE,
external_auth_token: SecretStr | None = None,
):
self.service_class_map: dict[ProviderType, type[GitService]] = {
ProviderType.GITHUB: GithubServiceImpl,
ProviderType.GITLAB: GitLabServiceImpl,
}
self.provider_tokens = provider_tokens
self.external_auth_token = external_auth_token
def _get_service(self, provider: ProviderType) -> GitService:
"""Helper method to instantiate a service for a given provider"""
token = self.provider_tokens[provider]
service_class = self.service_class_map[provider]
return service_class(
user_id=token.user_id,
external_auth_token=self.external_auth_token,
token=token.token,
)
async def get_user(self) -> User:
"""Get user information from the first available provider"""
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
return await service.get_user()
except Exception:
continue
raise AuthenticationError('Need valid provider token')
async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]:
"""Get latest token from services"""
tokens = {}
for provider in self.provider_tokens:
service = self._get_service(provider)
tokens[provider] = await service.get_latest_token()
return tokens
async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[Repository]:
"""Get repositories from all available providers"""
all_repos = []
for provider in self.provider_tokens:
try:
service = self._get_service(provider)
repos = await service.get_repositories(
page, per_page, sort, installation_id
)
all_repos.extend(repos)
except Exception:
continue
return all_repos

View File

@@ -0,0 +1,89 @@
from enum import Enum
from typing import Protocol
from pydantic import BaseModel, SecretStr
class TaskType(str, Enum):
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
FAILING_CHECKS = 'FAILING_CHECKS'
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
OPEN_ISSUE = 'OPEN_ISSUE'
OPEN_PR = 'OPEN_PR'
class SuggestedTask(BaseModel):
task_type: TaskType
repo: str
issue_number: int
title: str
class User(BaseModel):
id: int
login: str
avatar_url: str
company: str | None = None
name: str | None = None
email: str | None = None
class Repository(BaseModel):
id: int
full_name: str
stargazers_count: int | None = None
link_header: str | None = None
class AuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""
pass
class UnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""
pass
class GitService(Protocol):
"""Protocol defining the interface for Git service providers"""
def __init__(
self,
user_id: str | None,
token: SecretStr | None,
external_auth_token: SecretStr | None,
external_token_manager: bool = False,
) -> None:
"""Initialize the service with authentication details"""
...
async def get_latest_token(self) -> SecretStr:
"""Get latest working token of the users"""
...
async def get_user(self) -> User:
"""Get the authenticated user's information"""
...
async def search_repositories(
self,
query: str,
per_page: int,
sort: str,
order: str,
) -> list[Repository]:
"""Search for repositories"""
...
async def get_repositories(
self,
page: int,
per_page: int,
sort: str,
installation_id: int | None,
) -> list[Repository]:
"""Get repositories for the authenticated user"""
...

View File

@@ -0,0 +1,37 @@
from pydantic import SecretStr
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.gitlab.gitlab_service import GitLabService
from openhands.integrations.provider import ProviderType
async def validate_provider_token(token: SecretStr) -> ProviderType | None:
"""
Determine whether a token is for GitHub or GitLab by attempting to get user info
from both services.
Args:
token: The token to check
Returns:
'github' if it's a GitHub token
'gitlab' if it's a GitLab token
None if the token is invalid for both services
"""
# Try GitHub first
try:
github_service = GitHubService(token=token)
await github_service.get_user()
return ProviderType.GITHUB
except Exception:
pass
# Try GitLab next
try:
gitlab_service = GitLabService(token=token)
await gitlab_service.get_user()
return ProviderType.GITLAB
except Exception:
pass
return None

View File

@@ -249,7 +249,8 @@ class LLM(RetryMixin, DebugMixin):
# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
assert len(resp.choices) == 1
logger.debug(f'Response choices: {len(resp.choices)}')
assert len(resp.choices) >= 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
from openhands.core.message import Message, TextContent
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.llm import LLM
@@ -90,13 +91,10 @@ INTENT: Fix precision while maintaining FITS compliance"""
for forgotten_event in forgotten_events:
prompt += str(forgotten_event) + '\n\n'
messages = [Message(role='user', content=[TextContent(text=prompt)])]
response = self.llm.completion(
messages=[
{
'content': prompt,
'role': 'user',
},
],
messages=self.llm.format_messages_for_llm(messages),
)
summary = response.choices[0].message.content

View File

@@ -1,5 +1,6 @@
from litellm import ModelResponse
from openhands.core.config.agent_config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.schema import ActionType
@@ -16,7 +17,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
)
from openhands.events.event import Event
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
@@ -28,16 +29,21 @@ from openhands.events.observation import (
IPythonRunCellObservation,
UserRejectObservation,
)
from openhands.events.observation.agent import (
MicroagentKnowledge,
RecallObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
class ConversationMemory:
"""Processes event history into a coherent conversation for the agent."""
def __init__(self, prompt_manager: PromptManager):
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
self.agent_config = config
self.prompt_manager = prompt_manager
def process_events(
@@ -46,23 +52,24 @@ class ConversationMemory:
initial_messages: list[Message],
max_message_chars: int | None = None,
vision_is_active: bool = False,
enable_som_visual_browsing: bool = False,
) -> list[Message]:
"""Process state history into a list of messages for the LLM.
Ensures that tool call actions are processed correctly in function calling mode.
Args:
state: The state containing the history of events to convert
condensed_history: The condensed list of events to process
initial_messages: The initial messages to include in the result
condensed_history: The condensed history of events to convert
initial_messages: The initial messages to include in the conversation
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
"""
events = condensed_history
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
# Process special events first (system prompts, etc.)
messages = initial_messages
@@ -70,7 +77,7 @@ class ConversationMemory:
pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
for event in events:
for i, event in enumerate(events):
# create a regular message from an event
if isinstance(event, Action):
messages_to_add = self._process_action(
@@ -84,7 +91,9 @@ class ConversationMemory:
tool_call_id_to_message=tool_call_id_to_message,
max_message_chars=max_message_chars,
vision_is_active=vision_is_active,
enable_som_visual_browsing=enable_som_visual_browsing,
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
current_index=i,
events=events,
)
else:
raise ValueError(f'Unknown event type: {type(event)}')
@@ -270,6 +279,8 @@ class ConversationMemory:
max_message_chars: int | None = None,
vision_is_active: bool = False,
enable_som_visual_browsing: bool = False,
current_index: int = 0,
events: list[Event] | None = None,
) -> list[Message]:
"""Converts an observation into a message format that can be sent to the LLM.
@@ -291,6 +302,8 @@ class ConversationMemory:
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
current_index: The index of the current event in the events list (for deduplication)
events: The list of all events (for deduplication)
Returns:
list[Message]: A list containing the formatted message(s) for the observation.
@@ -372,6 +385,119 @@ class ConversationMemory:
elif isinstance(obs, AgentCondensationObservation):
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif (
isinstance(obs, RecallObservation)
and self.agent_config.enable_prompt_extensions
):
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
# everything is optional, check if they are present
if obs.repo_name or obs.repo_directory:
repo_info = RepositoryInfo(
repo_name=obs.repo_name or '',
repo_directory=obs.repo_directory or '',
)
else:
repo_info = None
if obs.runtime_hosts or obs.additional_agent_instructions:
runtime_info = RuntimeInfo(
available_hosts=obs.runtime_hosts,
additional_agent_instructions=obs.additional_agent_instructions,
)
else:
runtime_info = None
repo_instructions = (
obs.repo_instructions if obs.repo_instructions else ''
)
# Have some meaningful content before calling the template
has_repo_info = repo_info is not None and (
repo_info.repo_name or repo_info.repo_directory
)
has_runtime_info = runtime_info is not None and (
runtime_info.available_hosts
or runtime_info.additional_agent_instructions
)
has_repo_instructions = bool(repo_instructions.strip())
# Filter and process microagent knowledge
filtered_agents = []
if obs.microagent_knowledge:
# Exclude disabled microagents
filtered_agents = [
agent
for agent in obs.microagent_knowledge
if agent.name not in self.agent_config.disabled_microagents
]
has_microagent_knowledge = bool(filtered_agents)
# Generate appropriate content based on what is present
message_content = []
# Build the workspace context information
if has_repo_info or has_runtime_info or has_repo_instructions:
formatted_workspace_text = (
self.prompt_manager.build_workspace_context(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
)
)
message_content.append(TextContent(text=formatted_workspace_text))
# Add microagent knowledge if present
if has_microagent_knowledge:
formatted_microagent_text = (
self.prompt_manager.build_microagent_info(
triggered_agents=filtered_agents,
)
)
message_content.append(TextContent(text=formatted_microagent_text))
# Return the combined message if we have any content
if message_content:
message = Message(role='user', content=message_content)
else:
return []
elif obs.recall_type == RecallType.KNOWLEDGE:
# Use prompt manager to build the microagent info
# First, filter out agents that appear in earlier RecallObservations
filtered_agents = self._filter_agents_in_microagent_obs(
obs, current_index, events or []
)
# Create and return a message if there is microagent knowledge to include
if filtered_agents:
# Exclude disabled microagents
filtered_agents = [
agent
for agent in filtered_agents
if agent.name not in self.agent_config.disabled_microagents
]
# Only proceed if we still have agents after filtering out disabled ones
if filtered_agents:
formatted_text = self.prompt_manager.build_microagent_info(
triggered_agents=filtered_agents,
)
return [
Message(
role='user', content=[TextContent(text=formatted_text)]
)
]
# Return empty list if no microagents to include or all were disabled
return []
elif (
isinstance(obs, RecallObservation)
and not self.agent_config.enable_prompt_extensions
):
# If prompt extensions are disabled, we don't add any additional info
# TODO: test this
return []
else:
# If an observation message is not returned, it will cause an error
# when the LLM tries to return the next message
@@ -404,3 +530,51 @@ class ConversationMemory:
-1
].cache_prompt = True # Last item inside the message content
break
def _filter_agents_in_microagent_obs(
self, obs: RecallObservation, current_index: int, events: list[Event]
) -> list[MicroagentKnowledge]:
"""Filter out agents that appear in earlier RecallObservations.
Args:
obs: The current RecallObservation to filter
current_index: The index of the current event in the events list
events: The list of all events
Returns:
list[MicroagentKnowledge]: The filtered list of microagent knowledge
"""
if obs.recall_type != RecallType.KNOWLEDGE:
return obs.microagent_knowledge
# For each agent in the current microagent observation, check if it appears in any earlier microagent observation
filtered_agents = []
for agent in obs.microagent_knowledge:
# Keep this agent if it doesn't appear in any earlier observation
# that is, if this is the first microagent observation with this microagent
if not self._has_agent_in_earlier_events(agent.name, current_index, events):
filtered_agents.append(agent)
return filtered_agents
def _has_agent_in_earlier_events(
self, agent_name: str, current_index: int, events: list[Event]
) -> bool:
"""Check if an agent appears in any earlier RecallObservation in the event list.
Args:
agent_name: The name of the agent to look for
current_index: The index of the current event in the events list
events: The list of all events
Returns:
bool: True if the agent appears in an earlier RecallObservation, False otherwise
"""
for event in events[:current_index]:
# Note that this check includes the WORKSPACE_CONTEXT
if isinstance(event, RecallObservation):
if any(
agent.name == agent_name for agent in event.microagent_knowledge
):
return True
return False

292
openhands/memory/memory.py Normal file
View File

@@ -0,0 +1,292 @@
import asyncio
import os
import uuid
from typing import Callable
import openhands
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.observation.agent import (
MicroagentKnowledge,
RecallObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.stream import EventStream, EventStreamSubscriber
from openhands.microagent import (
BaseMicroAgent,
KnowledgeMicroAgent,
RepoMicroAgent,
load_microagents_from_dir,
)
from openhands.runtime.base import Runtime
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
GLOBAL_MICROAGENTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
class Memory:
"""
Memory is a component that listens to the EventStream for information retrieval actions
(a RecallAction) and publishes observations with the content (such as RecallObservation).
"""
sid: str
event_stream: EventStream
status_callback: Callable | None
loop: asyncio.AbstractEventLoop | None
def __init__(
self,
event_stream: EventStream,
sid: str,
status_callback: Callable | None = None,
):
self.event_stream = event_stream
self.sid = sid if sid else str(uuid.uuid4())
self.status_callback = status_callback
self.loop = None
self.event_stream.subscribe(
EventStreamSubscriber.MEMORY,
self.on_event,
self.sid,
)
# Additional placeholders to store user workspace microagents
self.repo_microagents: dict[str, RepoMicroAgent] = {}
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
# Store repository / runtime info to send them to the templating later
self.repository_info: RepositoryInfo | None = None
self.runtime_info: RuntimeInfo | None = None
# Load global microagents (Knowledge + Repo)
# from typically OpenHands/microagents (i.e., the PUBLIC microagents)
self._load_global_microagents()
def on_event(self, event: Event):
"""Handle an event from the event stream."""
asyncio.get_event_loop().run_until_complete(self._on_event(event))
async def _on_event(self, event: Event):
"""Handle an event from the event stream asynchronously."""
try:
if isinstance(event, RecallAction):
# if this is a workspace context recall (on first user message)
# create and add a RecallObservation
# with info about repo, runtime, instructions, etc. including microagent knowledge if any
if (
event.source == EventSource.USER
and event.recall_type == RecallType.WORKSPACE_CONTEXT
):
logger.debug('Workspace context recall')
workspace_obs: RecallObservation | NullObservation | None = None
workspace_obs = self._on_workspace_context_recall(event)
if workspace_obs is None:
workspace_obs = NullObservation(content='')
# important: this will release the execution flow from waiting for the retrieval to complete
workspace_obs._cause = event.id # type: ignore[union-attr]
self.event_stream.add_event(workspace_obs, EventSource.ENVIRONMENT)
return
# Handle knowledge recall (triggered microagents)
elif (
event.source == EventSource.USER
and event.recall_type == RecallType.KNOWLEDGE
):
logger.debug('Microagent knowledge recall')
microagent_obs: RecallObservation | NullObservation | None = None
microagent_obs = self._on_microagent_recall(event)
if microagent_obs is None:
microagent_obs = NullObservation(content='')
# important: this will release the execution flow from waiting for the retrieval to complete
microagent_obs._cause = event.id # type: ignore[union-attr]
self.event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
return
except Exception as e:
error_str = f'Error: {str(e.__class__.__name__)}'
logger.error(error_str)
self.send_error_message('STATUS$ERROR_MEMORY', error_str)
return
def _on_workspace_context_recall(
self, event: RecallAction
) -> RecallObservation | None:
"""Add repository and runtime information to the stream as a RecallObservation."""
# Create WORKSPACE_CONTEXT info:
# - repository_info
# - runtime_info
# - repository_instructions
# - microagent_knowledge
# Collect raw repository instructions
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
# Retrieve the context of repo instructions
for microagent in self.repo_microagents.values():
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content
# Find any matched microagents based on the query
microagent_knowledge = self._find_microagent_knowledge(event.query)
# Create observation if we have anything
if (
self.repository_info
or self.runtime_info
or repo_instructions
or microagent_knowledge
):
obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name=self.repository_info.repo_name
if self.repository_info and self.repository_info.repo_name is not None
else '',
repo_directory=self.repository_info.repo_directory
if self.repository_info
and self.repository_info.repo_directory is not None
else '',
repo_instructions=repo_instructions if repo_instructions else '',
runtime_hosts=self.runtime_info.available_hosts
if self.runtime_info and self.runtime_info.available_hosts is not None
else {},
additional_agent_instructions=self.runtime_info.additional_agent_instructions
if self.runtime_info
and self.runtime_info.additional_agent_instructions is not None
else '',
microagent_knowledge=microagent_knowledge,
content='Added workspace context',
)
return obs
return None
def _on_microagent_recall(
self,
event: RecallAction,
) -> RecallObservation | None:
"""When a microagent action triggers microagents, create a RecallObservation with structured data."""
# Find any matched microagents based on the query
microagent_knowledge = self._find_microagent_knowledge(event.query)
# Create observation if we have anything
if microagent_knowledge:
obs = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=microagent_knowledge,
content='Retrieved knowledge from microagents',
)
return obs
return None
def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
"""Find microagent knowledge based on a query.
Args:
query: The query to search for microagent triggers
Returns:
A list of MicroagentKnowledge objects for matched triggers
"""
recalled_content: list[MicroagentKnowledge] = []
# skip empty queries
if not query:
return recalled_content
# Search for microagent triggers in the query
for name, microagent in self.knowledge_microagents.items():
trigger = microagent.match_trigger(query)
if trigger:
logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
recalled_content.append(
MicroagentKnowledge(
name=microagent.name,
trigger=trigger,
content=microagent.content,
)
)
return recalled_content
def load_user_workspace_microagents(
self, user_microagents: list[BaseMicroAgent]
) -> None:
"""
This method loads microagents from a user's cloned repo or workspace directory.
This is typically called from agent_session or setup once the workspace is cloned.
"""
logger.info(
'Loading user workspace microagents: %s', [m.name for m in user_microagents]
)
for user_microagent in user_microagents:
if isinstance(user_microagent, KnowledgeMicroAgent):
self.knowledge_microagents[user_microagent.name] = user_microagent
elif isinstance(user_microagent, RepoMicroAgent):
self.repo_microagents[user_microagent.name] = user_microagent
def _load_global_microagents(self) -> None:
"""
Loads microagents from the global microagents_dir
"""
repo_agents, knowledge_agents, _ = load_microagents_from_dir(
GLOBAL_MICROAGENTS_DIR
)
for name, agent in knowledge_agents.items():
if isinstance(agent, KnowledgeMicroAgent):
self.knowledge_microagents[name] = agent
for name, agent in repo_agents.items():
if isinstance(agent, RepoMicroAgent):
self.repo_microagents[name] = agent
def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
"""Store repository info so we can reference it in an observation."""
if repo_name or repo_directory:
self.repository_info = RepositoryInfo(repo_name, repo_directory)
else:
self.repository_info = None
def set_runtime_info(self, runtime: Runtime) -> None:
"""Store runtime info (web hosts, ports, etc.)."""
# e.g. { '127.0.0.1': 8080 }
if runtime.web_hosts or runtime.additional_agent_instructions:
self.runtime_info = RuntimeInfo(
available_hosts=runtime.web_hosts,
additional_agent_instructions=runtime.additional_agent_instructions,
)
else:
self.runtime_info = None
def send_error_message(self, message_id: str, message: str):
"""Sends an error message if the callback function was provided."""
if self.status_callback:
try:
if self.loop is None:
self.loop = asyncio.get_running_loop()
asyncio.run_coroutine_threadsafe(
self._send_status_message('error', message_id, message), self.loop
)
except RuntimeError as e:
logger.error(
f'Error sending status message: {e.__class__.__name__}',
stack_info=False,
)
async def _send_status_message(self, msg_type: str, id: str, message: str):
"""Sends a status message to the client."""
if self.status_callback:
self.status_callback(msg_type, id, message)

View File

@@ -27,7 +27,7 @@ class GithubIssueHandler(IssueHandlerInterface):
def get_headers(self) -> dict[str, str]:
return {
'Authorization': f'token {self.token}',
'Authorization': f'Bearer {self.token}',
'Accept': 'application/vnd.github.v3+json',
}
@@ -450,7 +450,7 @@ class GithubPRHandler(GithubIssueHandler):
"""Download comments for a specific pull request from Github."""
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
headers = {
'Authorization': f'token {self.token}',
'Authorization': f'Bearer {self.token}',
'Accept': 'application/vnd.github.v3+json',
}
params = {'per_page': 100, 'page': 1}

View File

@@ -158,7 +158,7 @@ class ActionExecutor:
self.bash_session: BashSession | None = None
self.lock = asyncio.Lock()
self.plugins: dict[str, Plugin] = {}
self.file_editor = OHEditor()
self.file_editor = OHEditor(workspace_root=self._initial_cwd)
self.browser = BrowserEnv(browsergym_eval_env)
self.start_time = time.time()
self.last_execution_time = self.start_time

View File

@@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = False,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.sid = sid
self.event_stream = event_stream
@@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
)
self.github_user_id = github_user_id
self.user_id = user_id
def setup_initial_env(self) -> None:
if self.attach_to_existing:
@@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
assert event.timeout is not None
try:
if isinstance(event, CmdRunAction):
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
if self.user_id and '$GITHUB_TOKEN' in event.command:
gh_client = GithubServiceImpl(
user_id=self.github_user_id, external_token_manager=True
external_auth_id=self.user_id, external_token_manager=True
)
token = await gh_client.get_latest_token()
if token:

View File

@@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime):
status_callback: Any | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.session = HttpSession()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
@abstractmethod

View File

@@ -1,3 +1,4 @@
import logging
import os
from typing import Callable
from urllib.parse import urlparse
@@ -45,7 +46,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
super().__init__(
config,
@@ -56,7 +57,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
if self.config.sandbox.api_key is None:
raise ValueError(
@@ -425,10 +426,11 @@ class RemoteRuntime(ActionExecutionClient):
return self._send_action_server_request_impl(method, url, **kwargs)
retry_decorator = tenacity.retry(
retry=tenacity.retry_if_exception_type(ConnectionError),
retry=tenacity.retry_if_exception_type(requests.ConnectionError),
stop=tenacity.stop_after_attempt(3)
| stop_if_should_exit()
| self._stop_if_closed,
before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
)
return retry_decorator(self._send_action_server_request_impl)(

View File

@@ -1,6 +1,13 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
"""Get GitHub token from request state. For backward compatibility."""
return getattr(request.state, 'provider_tokens', None)
def get_access_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'access_token', None)
@@ -11,8 +18,18 @@ def get_user_id(request: Request) -> str | None:
def get_github_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'github_token', None)
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].token
return None
def get_github_user_id(request: Request) -> str | None:
return getattr(request.state, 'github_user_id', None)
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].user_id
return None

View File

@@ -46,7 +46,12 @@ class ConversationManager(ABC):
@abstractmethod
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStream | None:
"""Join a conversation and return its event stream."""
@@ -74,6 +79,7 @@ class ConversationManager(ABC):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
"""Start an event loop if one is not already running"""

View File

@@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager):
return c
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
):
logger.info(
f'join_conversation:{sid}:{connection_id}',
@@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager):
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings, user_id)
return await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
)
for event in event_stream.get_events(reverse=True):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in (
@@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager):
logger.error('error_cleaning_stale')
await asyncio.sleep(_CLEANUP_INTERVAL)
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
async def _get_conversation_store(
self, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
conversation_store_class = self._conversation_store_class
if not conversation_store_class:
self._conversation_store_class = conversation_store_class = get_impl(
ConversationStore, # type: ignore
self.server_config.conversation_store_class,
)
store = await conversation_store_class.get_instance(self.config, user_id)
store = await conversation_store_class.get_instance(
self.config, user_id, github_user_id
)
return store
async def get_running_agent_loops(
@@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
session: Session | None = None
@@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager):
extra={'session_id': sid, 'user_id': user_id},
)
# Get the conversations sorted (oldest first)
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(
user_id, github_user_id
)
conversations = await conversation_store.get_all_metadata(response_ids)
conversations.sort(key=_last_updated_at_key, reverse=True)
@@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager):
try:
session.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER,
self._create_conversation_update_callback(user_id, sid),
self._create_conversation_update_callback(
user_id, github_user_id, sid
),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager):
)
def _create_conversation_update_callback(
self, user_id: str | None, conversation_id: str
self, user_id: str | None, github_user_id: str | None, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
self._update_timestamp_for_conversation,
GENERAL_TIMEOUT,
user_id,
github_user_id,
conversation_id,
)
return callback
async def _update_timestamp_for_conversation(
self, user_id: str, conversation_id: str
self, user_id: str, github_user_id: str, conversation_id: str
):
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(user_id, github_user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)
await conversation_store.save_metadata(conversation)

View File

@@ -6,10 +6,14 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.observation import (
NullObservation,
)
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.agent import (
AgentStateChangedObservation,
RecallObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.shared import (
@@ -35,7 +39,9 @@ async def connect(connection_id: str, environ):
cookies_str = environ.get('HTTP_COOKIE', '')
conversation_validator = ConversationValidatorImpl()
user_id = await conversation_validator.validate(conversation_id, cookies_str)
user_id, github_user_id = await conversation_validator.validate(
conversation_id, cookies_str
)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
@@ -46,7 +52,7 @@ async def connect(connection_id: str, environ):
)
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, settings, user_id
conversation_id, connection_id, settings, user_id, github_user_id
)
agent_state_changed = None
@@ -54,10 +60,7 @@ async def connect(connection_id: str, environ):
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
),
(NullAction, NullObservation, RecallAction, RecallObservation),
):
continue
elif isinstance(event, AgentStateChangedObservation):

View File

@@ -194,10 +194,14 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
settings = await settings_store.load()
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'github_token', None) is None:
if settings and settings.github_token:
request.state.github_token = settings.github_token
if getattr(request.state, 'provider_tokens', None) is None:
if (
settings
and settings.secrets_store
and settings.secrets_store.provider_tokens
):
request.state.provider_tokens = settings.secrets_store.provider_tokens
else:
request.state.github_token = None
request.state.provider_tokens = None
return await call_next(request)

View File

@@ -3,147 +3,168 @@ from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
SuggestedTask,
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
ProviderType,
)
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
from openhands.integrations.service_types import (
AuthenticationError,
Repository,
SuggestedTask,
UnknownException,
User,
)
from openhands.server.auth import get_access_token, get_provider_tokens
app = APIRouter(prefix='/api/github')
@app.get('/repositories', response_model=list[GitHubRepository])
@app.get('/repositories', response_model=list[Repository])
async def get_github_repositories(
page: int = 1,
per_page: int = 10,
sort: str = 'pushed',
installation_id: int | None = None,
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
repos: list[Repository] = await client.get_repositories(
page, per_page, sort, installation_id
)
return repos
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
repos: list[GitHubRepository] = await client.get_repositories(
page, per_page, sort, installation_id
)
return repos
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/user', response_model=GitHubUser)
@app.get('/user', response_model=User)
async def get_github_user(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens:
client = ProviderHandler(provider_tokens=provider_tokens, external_auth_token=access_token)
try:
user: User = await client.get_user()
return user
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
user: GitHubUser = await client.get_user()
return user
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/installations', response_model=list[int])
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/search/repositories', response_model=list[GitHubRepository])
@app.get('/search/repositories', response_model=list[Repository])
async def search_github_repositories(
query: str,
per_page: int = 5,
sort: str = 'stars',
order: str = 'desc',
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token),
):
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
repos: list[Repository] = await client.search_repositories(
query, per_page, sort, order
)
return repos
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
repos: list[GitHubRepository] = await client.search_repositories(
query, per_page, sort, order
)
return repos
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_401_UNAUTHORIZED,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
async def get_suggested_tasks(
github_user_id: str | None = Depends(get_github_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
access_token: SecretStr | None = Depends(get_access_token),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
access_token: SecretStr | None = Depends(get_access_token)
):
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
@@ -151,23 +172,30 @@ async def get_suggested_tasks(
- PRs owned by the user
- Issues assigned to the user.
"""
client = GithubServiceImpl(
user_id=github_user_id,
external_auth_token=access_token,
github_token=github_user_token,
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
client = GithubServiceImpl(
user_id=token.user_id, external_auth_token=access_token, token=token.token
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
return tasks
except AuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=401,
)
except UnknownException as e:
return JSONResponse(
content=str(e),
status_code=500,
)
return JSONResponse(
content='GitHub token required.',
status_code=status.HTTP_401_UNAUTHORIZED,
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
return tasks
except GhAuthenticationError as e:
return JSONResponse(
content=str(e),
status_code=401,
)
except GHUnknownException as e:
return JSONResponse(
content=str(e),
status_code=500,
)

View File

@@ -8,8 +8,14 @@ from pydantic import BaseModel, SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import ProviderType
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
from openhands.server.auth import (
get_access_token,
get_github_user_id,
get_provider_tokens,
get_user_id,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -72,12 +78,12 @@ async def _create_new_conversation(
logger.warn('Settings not present, not starting conversation')
raise MissingSettingsError('Settings not found')
session_init_args['github_token'] = token or SecretStr('')
session_init_args['provider_token'] = token
session_init_args['selected_repository'] = selected_repository
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@@ -99,7 +105,8 @@ async def _create_new_conversation(
ConversationMetadata(
conversation_id=conversation_id,
title=conversation_title,
github_user_id=user_id,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
@@ -121,7 +128,10 @@ async def _create_new_conversation(
image_urls=image_urls or [],
)
await conversation_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id, initial_message_action
conversation_id,
conversation_init_data,
user_id,
initial_user_msg=initial_message_action,
)
logger.info(f'Finished initializing conversation {conversation_id}')
@@ -136,13 +146,18 @@ async def new_conversation(request: Request, data: InitSessionRequest):
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
user_id = get_github_user_id(request)
gh_client = GithubServiceImpl(
user_id=user_id,
external_auth_token=get_access_token(request),
github_token=get_github_token(request),
)
github_token = await gh_client.get_latest_token()
user_id = None
github_token = None
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
token = provider_tokens[ProviderType.GITHUB]
user_id = token.user_id
gh_client = GithubServiceImpl(
user_id=user_id,
external_auth_token=get_access_token(request),
token=token.token,
)
github_token = await gh_client.get_latest_token()
selected_repository = data.selected_repository
selected_branch = data.selected_branch
@@ -152,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id,
get_user_id(request),
github_token,
selected_repository,
selected_branch,
@@ -191,7 +206,7 @@ async def search_conversations(
limit: int = 20,
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
@@ -210,7 +225,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_github_user_id(request), set(conversation_ids)
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@@ -230,7 +245,7 @@ async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
@@ -246,7 +261,7 @@ async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@@ -262,7 +277,7 @@ async def delete_conversation(
request: Request,
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)

View File

@@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.server.auth import get_github_token, get_user_id
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.utils import validate_provider_token
from openhands.server.auth import get_provider_tokens, get_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
from openhands.server.shared import SettingsStoreImpl, config
@@ -23,14 +24,14 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
content={'error': 'Settings not found'},
)
token_is_set = bool(user_id) or bool(get_github_token(request))
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
settings_with_token_data = GETSettingsModel(
**settings.model_dump(),
github_token_is_set=token_is_set,
github_token_is_set=github_token_is_set,
)
settings_with_token_data.llm_api_key = settings.llm_api_key
del settings_with_token_data.github_token
del settings_with_token_data.secrets_store
return settings_with_token_data
except Exception as e:
logger.warning(f'Invalid token: {e}')
@@ -45,26 +46,27 @@ async def store_settings(
request: Request,
settings: POSTSettingsModel,
) -> JSONResponse:
# Check if token is valid
if settings.github_token:
try:
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GithubServiceImpl(
user_id=None,
external_auth_token=None,
github_token=SecretStr(settings.github_token),
)
await github.get_user()
# Check provider tokens are valid
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
settings.provider_tokens = {
k: v for k, v in settings.provider_tokens.items() if k in provider_types
}
except Exception as e:
logger.warning(f'Invalid GitHub token: {e}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': 'Invalid GitHub token. Please make sure it is valid.'
},
)
# Determine whether tokens are valid
for token_type, token_value in settings.provider_tokens.items():
if token_value:
confirmed_token_type = await validate_provider_token(
SecretStr(token_value)
)
if not confirmed_token_type or confirmed_token_type.value != token_type:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': f'Invalid token. Please make sure it is a valid {token_type} token.'
},
)
try:
settings_store = await SettingsStoreImpl.get_instance(
@@ -72,32 +74,54 @@ async def store_settings(
)
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
# LLM key isn't on the frontend, so we need to keep it if unset
# Keep existing LLM settings if not provided
if settings.llm_api_key is None:
settings.llm_api_key = existing_settings.llm_api_key
if settings.llm_model is None:
settings.llm_model = existing_settings.llm_model
if settings.llm_base_url is None:
settings.llm_base_url = existing_settings.llm_base_url
if settings.github_token is None:
settings.github_token = existing_settings.github_token
# Keep existing analytics consent if not provided
if settings.user_consents_to_analytics is None:
settings.user_consents_to_analytics = (
existing_settings.user_consents_to_analytics
)
if settings.llm_model is None:
settings.llm_model = existing_settings.llm_model
if settings.unset_github_token:
settings.secrets_store.provider_tokens = {}
settings.provider_tokens = {}
else: # Only merge if not unsetting tokens
if settings.provider_tokens:
if existing_settings.secrets_store:
existing_providers = [
provider.value
for provider in existing_settings.secrets_store.provider_tokens
]
if settings.llm_base_url is None:
settings.llm_base_url = existing_settings.llm_base_url
response = JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Settings stored'},
)
if settings.unset_github_token:
settings.github_token = None
# Merge incoming settings store with the existing one
for provider, token_value in settings.provider_tokens.items():
if provider in existing_providers and not token_value:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
)
if existing_token and existing_token.token:
settings.provider_tokens[provider] = (
existing_token.token.get_secret_value()
)
else: # nothing passed in means keep current settings
provider_tokens = existing_settings.secrets_store.provider_tokens
settings.provider_tokens = {
provider.value: data.token.get_secret_value()
if data.token
else None
for provider, data in provider_tokens.items()
}
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:
@@ -106,9 +130,11 @@ async def store_settings(
)
settings = convert_to_settings(settings)
await settings_store.store(settings)
return response
return JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Settings stored'},
)
except Exception as e:
logger.warning(f'Something went wrong storing settings: {e}')
return JSONResponse(
@@ -127,8 +153,19 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
if key in Settings.model_fields # Ensures only `Settings` fields are included
}
# Convert the `llm_api_key` and `github_token` to a `SecretStr` instance
# Convert the `llm_api_key` to a `SecretStr` instance
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
filtered_settings_data['github_token'] = settings_with_token_data.github_token
return Settings(**filtered_settings_data)
# Create a new Settings instance without provider tokens
settings = Settings(**filtered_settings_data)
# Update provider tokens if any are provided
if settings_with_token_data.provider_tokens:
for token_type, token_value in settings_with_token_data.provider_tokens.items():
if token_value:
provider = ProviderType(token_type)
settings.secrets_store.provider_tokens[provider] = ProviderToken(
token=SecretStr(token_value), user_id=None
)
return settings

View File

@@ -15,7 +15,8 @@ from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
@@ -52,7 +53,7 @@ class AgentSession:
sid: str,
file_store: FileStore,
status_callback: Callable | None = None,
github_user_id: str | None = None,
user_id: str | None = None,
):
"""Initializes a new instance of the Session class
@@ -65,9 +66,9 @@ class AgentSession:
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self._status_callback = status_callback
self.github_user_id = github_user_id
self.user_id = user_id
self.logger = OpenHandsLoggerAdapter(
extra={'session_id': sid, 'user_id': github_user_id}
extra={'session_id': sid, 'user_id': user_id}
)
async def start(
@@ -126,6 +127,15 @@ class AgentSession:
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)
repo_directory = None
if self.runtime and runtime_connected and selected_repository:
repo_directory = selected_repository.split('/')[-1]
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,
)
if github_token:
self.event_stream.set_secrets(
{
@@ -231,7 +241,7 @@ class AgentSession:
kwargs = {}
if runtime_cls == RemoteRuntime:
kwargs['github_user_id'] = self.github_user_id
kwargs['user_id'] = self.user_id
self.runtime = runtime_cls(
config=config,
@@ -260,26 +270,14 @@ class AgentSession:
)
return False
repo_directory = None
if selected_repository:
repo_directory = await call_sync_from_async(
await call_sync_from_async(
self.runtime.clone_repo,
github_token,
selected_repository,
selected_branch,
)
if agent.prompt_manager:
agent.prompt_manager.set_runtime_info(self.runtime)
microagents: list[BaseMicroAgent] = await call_sync_from_async(
self.runtime.get_microagents_from_selected_repo, selected_repository
)
agent.prompt_manager.load_microagents(microagents)
if selected_repository and repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)
self.logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
)
@@ -342,6 +340,29 @@ class AgentSession:
return controller
async def _create_memory(
self, selected_repository: str | None, repo_directory: str | None
) -> Memory:
memory = Memory(
event_stream=self.event_stream,
sid=self.sid,
status_callback=self._status_callback,
)
if self.runtime:
# sets available hosts and other runtime info
memory.set_runtime_info(self.runtime)
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroAgent] = await call_sync_from_async(
self.runtime.get_microagents_from_selected_repo, selected_repository
)
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(selected_repository, repo_directory)
return memory
def _maybe_restore_state(self) -> State | None:
"""Helper method to handle state restore logic."""
restored_state = None

View File

@@ -8,6 +8,6 @@ class ConversationInitData(Settings):
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
github_token: SecretStr | None = Field(default=None)
provider_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)

View File

@@ -61,7 +61,7 @@ class Session:
sid,
file_store,
status_callback=self.queue_status_message,
github_user_id=user_id,
user_id=user_id,
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
@@ -123,11 +123,11 @@ class Session:
agent = Agent.get_cls(agent_cls)(llm, agent_config)
github_token = None
provider_token = None
selected_repository = None
selected_branch = None
if isinstance(settings, ConversationInitData):
github_token = settings.github_token
provider_token = settings.provider_token
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
@@ -140,7 +140,7 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
github_token=provider_token,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,

View File

@@ -1,10 +1,17 @@
from __future__ import annotations
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from pydantic import (
BaseModel,
SecretStr,
SerializationInfo,
field_serializer,
model_validator,
)
from pydantic.json import pydantic_encoder
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.utils import load_app_config
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
class Settings(BaseModel):
@@ -21,7 +28,7 @@ class Settings(BaseModel):
llm_api_key: SecretStr | None = None
llm_base_url: str | None = None
remote_runtime_resource_factor: int | None = None
github_token: SecretStr | None = None
secrets_store: SecretStore = SecretStore()
enable_default_condenser: bool = False
enable_sound_notifications: bool = False
user_consents_to_analytics: bool | None = None
@@ -36,24 +43,65 @@ class Settings(BaseModel):
if context and context.get('expose_secrets', False):
return llm_api_key.get_secret_value()
return pydantic_encoder(llm_api_key)
return pydantic_encoder(llm_api_key) if llm_api_key else None
@field_serializer('github_token')
def github_token_serializer(
self, github_token: SecretStr | None, info: SerializationInfo
):
"""Custom serializer for the GitHub token.
@staticmethod
def _convert_token_value(
token_type: ProviderType, token_value: str | dict
) -> ProviderToken | None:
"""Convert a token value to a ProviderToken object."""
if isinstance(token_value, dict):
token_str = token_value.get('token')
if not token_str:
return None
return ProviderToken(
token=SecretStr(token_str),
user_id=token_value.get('user_id'),
)
if isinstance(token_value, str) and token_value:
return ProviderToken(token=SecretStr(token_value), user_id=None)
return None
To serialize the token instead of ********, set expose_secrets to True in the serialization context.
"""
if github_token is None:
return None
@model_validator(mode='before')
@classmethod
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
"""Convert provider tokens from JSON format to SecretStore format."""
if not isinstance(data, dict):
return data
context = info.context
if context and context.get('expose_secrets', False):
return github_token.get_secret_value()
secrets_store = data.get('secrets_store')
if not isinstance(secrets_store, dict):
return data
return pydantic_encoder(github_token)
tokens = secrets_store.get('provider_tokens')
if not isinstance(tokens, dict):
return data
converted_tokens = {}
for token_type_str, token_value in tokens.items():
if not token_value:
continue
try:
token_type = ProviderType(token_type_str)
except ValueError:
continue
provider_token = cls._convert_token_value(token_type, token_value)
if provider_token:
converted_tokens[token_type] = provider_token
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
return data
@field_serializer('secrets_store')
def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo):
"""Custom serializer for secrets store."""
return {
'provider_tokens': secrets.provider_tokens_serializer(
secrets.provider_tokens, info
)
}
@staticmethod
def from_config() -> Settings | None:
@@ -73,7 +121,7 @@ class Settings(BaseModel):
llm_api_key=llm_config.api_key,
llm_base_url=llm_config.base_url,
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
github_token=None,
provider_tokens={},
)
return settings
@@ -84,14 +132,12 @@ class POSTSettingsModel(Settings):
"""
unset_github_token: bool | None = None
github_token: str | None = (
None # This is a string because it's coming from the frontend
)
# Override provider_tokens to accept string tokens from frontend
provider_tokens: dict[str, str] = {}
# Override the serializer for the GitHub token to handle the string input
@field_serializer('github_token')
def github_token_serializer(self, github_token: str | None):
return github_token
@field_serializer('provider_tokens')
def provider_tokens_serializer(self, provider_tokens: dict[str, str]):
return provider_tokens
class GETSettingsModel(Settings):

View File

@@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all
class ConversationStore(ABC):
"""
Storage for conversation metadata. May or may not support multiple users depending on the environment
"""
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
@abstractmethod
async def save_metadata(self, metadata: ConversationMetadata) -> None:
"""Store conversation metadata"""
"""Store conversation metadata."""
@abstractmethod
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
"""Load conversation metadata"""
"""Load conversation metadata."""
async def validate_metadata(
self, conversation_id: str, user_id: str, github_user_id: str
) -> bool:
"""Validate that conversation belongs to the current user."""
# TODO: remove github_user_id after transition to Keycloak is complete.
metadata = await self.get_metadata(conversation_id)
if (not metadata.user_id and not metadata.github_user_id) or (
metadata.user_id != user_id and metadata.github_user_id != github_user_id
):
return False
else:
return True
@abstractmethod
async def delete_metadata(self, conversation_id: str) -> None:
"""delete conversation metadata"""
"""Delete conversation metadata."""
@abstractmethod
async def exists(self, conversation_id: str) -> bool:
"""Check if conversation exists"""
"""Check if conversation exists."""
@abstractmethod
async def search(
@@ -49,6 +60,6 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
"""Get a store for the user represented by the token given"""

View File

@@ -7,7 +7,7 @@ class ConversationValidator:
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def validate(self, conversation_id: str, cookies_str: str):
return None
return None, None
conversation_validator_cls = os.environ.get(

View File

@@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore):
try:
conversations.append(await self.get_metadata(conversation_id))
except Exception:
logger.error(
f'Error loading conversation: {conversation_id}',
logger.warning(
f'Could not load conversation metadata: {conversation_id}',
)
conversations.sort(key=_sort_key, reverse=True)
conversations = conversations[start:end]
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
@dataclass
class ConversationMetadata:
conversation_id: str
user_id: str | None
github_user_id: str | None
selected_repository: str | None
selected_branch: str | None = None

Some files were not shown because too many files have changed in this diff Show More