Compare commits

...

38 Commits

Author SHA1 Message Date
OpenHands ba25b02978 Fix issue #4735: Update msw mocks (#4736) 2024-11-04 16:58:56 +00:00
Xingyao Wang 966da7b7c8 feat(agent, CodeAct 2.2): native CodeAct support for Browsing (#4667)
Co-authored-by: tofarr <tofarr@gmail.com>
2024-11-05 00:27:27 +08:00
sp.wack f0af90bff3 fix(frontend): Always return user is authed if mode is oss (#4733) 2024-11-04 16:24:23 +00:00
Engel Nyst 1638968509 History microfixes (#4728) 2024-11-04 16:37:22 +01:00
Robert Brennan 250fcbe62c Various async fixes (#4722) 2024-11-04 10:08:09 -05:00
sp.wack 0595d2336a feat: Analytics with PostHog (#4655) 2024-11-04 09:57:56 +00:00
sp.wack 387c8f1df3 feat(frontend): Make loader synchronous (#4689) 2024-11-04 11:26:30 +02:00
Polygons1 f6c2b287bc Fix for #4717 (#4721) 2024-11-04 08:24:00 +08:00
Xingyao Wang ab188d026d Revert "Fix permissions on __init__.py" (#4718) 2024-11-04 05:10:43 +08:00
Robert Brennan 316fc260f6 Fix list-files async calls (#4720)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2024-11-03 10:52:53 -08:00
Robert Brennan aab7fa483b Fix permissions on __init__.py (#4713) 2024-11-03 22:14:42 +08:00
Rohit Malhotra 496364ce53 Adding PR label trigger for openhands-resolver (#4712) 2024-11-02 20:19:30 -04:00
Ryan H. Tran 4446d3180f fix: use None check instead of falsy (#4705) 2024-11-02 12:44:03 -04:00
Robert Brennan 7b8241e424 fix auth when there are no allow lists (#4707) 2024-11-02 16:25:35 +00:00
Abhijeetsingh Meena 8857f02083 [Eval] DiscoveryBench OpenHands Integration (#4627)
Signed-off-by: Abhijeetsingh Meena <abhijeet040403@gmail.com>
Co-authored-by: Harshit Surana <surana.h@gmail.com>
2024-11-02 07:24:34 -04:00
Xingyao Wang 1747b3d6b2 fix: prompt caching (#4704) 2024-11-02 07:21:21 -04:00
Robert Brennan 36623a16da Minor auth fixes (#4699) 2024-11-01 18:33:29 -07:00
OpenHands 9d3b77bffc Fix issue #4695: [Bug]: Dependabot PRs fail on "Update PR Description" github action step (#4697) 2024-11-01 18:32:31 -07:00
OpenHands 2682518d0e Fix issue #4692: [Bug]: Slack link no longer working (#4693) 2024-11-01 18:34:20 -05:00
Robert Brennan b27fabe504 Add Google Sheets integration for GitHub user verification (#4671)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2024-11-01 15:17:15 -07:00
Xingyao Wang adf7ab5849 fix: handle the case where LLM assistant return None instead of empty string (#4690) 2024-11-01 19:13:01 +00:00
Robert Brennan 456998175f Fix authentication (#4686) 2024-11-01 10:54:06 -07:00
Graham Neubig b4afd9f170 Update README.md w/ github resolver link (#4679) 2024-11-01 13:07:35 +00:00
sp.wack 73c7375b92 fix(frontend): Prevent editor from changing width unpredictably (#4659) 2024-11-01 14:04:39 +02:00
tofarr 6414b1af6e Fix agent session error in logs (#4669) 2024-11-01 10:50:56 +08:00
tofarr dd55290f4e Fix : app unresponsive on startup (#4668) 2024-10-31 14:30:33 -07:00
tofarr be77baea31 refactor: remove unused methods and constants from Session class (#4662)
Co-authored-by: openhands <openhands@all-hands.dev>
2024-10-31 14:55:37 -06:00
Robert Brennan a812e2b5f1 Add cookie-based authentication to all routes (#4642)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
2024-10-31 12:18:42 -07:00
tofarr 4ebff5aaf3 Fix unawaited (#4665) 2024-10-31 19:16:37 +00:00
Engel Nyst 0687608feb [Arch proposal] ENVIRONMENT event source (#4584)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
2024-11-01 02:33:13 +08:00
Ziru "Ron" Chen db4e1dbbec [eval] Add ScienceAgentBench. (#4645)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
2024-11-01 02:30:55 +08:00
Robert Brennan 9442e4f9e3 dont run pr update on forks (#4663) 2024-11-01 01:55:50 +08:00
Robert Brennan e17f7b22a6 Remove hidden commands from feedback (#4597)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2024-10-31 08:49:47 -07:00
mamoodi ce6939fc0d Release 0.12.0 - Pending Release Notes Prep (#4650) 2024-10-31 13:14:01 +00:00
Xingyao Wang 4705ef9ec2 chore: do not include "status" dict in share-openhands (#4620) 2024-10-31 20:35:35 +08:00
Xingyao Wang 9c2b48ff5d fix(eval): SWE-Bench instance with upper-case instance id (#4649) 2024-10-30 21:24:18 +00:00
Robert Brennan 87906b96a7 Add job to update PR description with docker run command (#4550)
Co-authored-by: openhands <openhands@all-hands.dev>
2024-10-30 16:42:03 -04:00
Xingyao Wang c0a0d46eb2 test(runtime) #4623: file permission when running the file_editor (#4628)
Co-authored-by: openhands <openhands@all-hands.dev>
2024-10-31 04:34:34 +08:00
98 changed files with 3693 additions and 921 deletions
+46
View File
@@ -399,3 +399,49 @@ jobs:
run: |
echo "Some runtime tests failed or were cancelled"
exit 1
update_pr_description:
name: Update PR Description
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]'
needs: [ghcr_build_runtime]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Get short SHA
id: short_sha
run: echo "SHORT_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-7)" >> $GITHUB_OUTPUT
- name: Update PR Description
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
REPO: ${{ github.repository }}
SHORT_SHA: ${{ steps.short_sha.outputs.SHORT_SHA }}
run: |
echo "updating PR description"
DOCKER_RUN_COMMAND="docker run -it --rm \
-p 3000:3000 \
-v /var/run/docker.sock:/var/run/docker.sock \
--add-host host.docker.internal:host-gateway \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:$SHORT_SHA-nikolaik \
--name openhands-app-$SHORT_SHA \
ghcr.io/all-hands-ai/runtime:$SHORT_SHA"
PR_BODY=$(gh pr view $PR_NUMBER --json body --jq .body)
if echo "$PR_BODY" | grep -q "To run this PR locally, use the following command:"; then
UPDATED_PR_BODY=$(echo "${PR_BODY}" | sed -E "s|docker run -it --rm.*|$DOCKER_RUN_COMMAND|")
else
UPDATED_PR_BODY="${PR_BODY}
---
To run this PR locally, use the following command:
\`\`\`
$DOCKER_RUN_COMMAND
\`\`\`"
fi
echo "updated body: $UPDATED_PR_BODY"
gh pr edit $PR_NUMBER --body "$UPDATED_PR_BODY"
+2
View File
@@ -3,6 +3,8 @@ name: Resolve Issues with OpenHands
on:
issues:
types: [labeled]
pull_request:
types: [labeled]
jobs:
call-openhands-resolver:
+1
View File
@@ -174,6 +174,7 @@ evaluation/bird/data
evaluation/gaia/data
evaluation/gorilla/data
evaluation/toolqa/data
evaluation/scienceagentbench/benchmark
# frontend
+7 -6
View File
@@ -12,7 +12,7 @@
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
<br/>
<a href="https://join.slack.com/t/opendevin/shared_invite/zt-2oikve2hu-UDxHeo8nsE69y6T7yFX_BA"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2tom0er4l-JeNUGHt_AxpEfIBstbLPiw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
<br/>
@@ -38,15 +38,15 @@ See the [Installation](https://docs.all-hands.dev/modules/usage/installation) gu
system requirements and more information.
```bash
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik
docker run -it --rm --pull=always \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
-v /var/run/docker.sock:/var/run/docker.sock \
-p 3000:3000 \
--add-host host.docker.internal:host-gateway \
--name openhands-app \
docker.all-hands.dev/all-hands-ai/openhands:0.11
docker.all-hands.dev/all-hands-ai/openhands:0.12
```
You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)!
@@ -59,7 +59,8 @@ works best, but you have [many options](https://docs.all-hands.dev/modules/usage
You can also [connect OpenHands to your local filesystem](https://docs.all-hands.dev/modules/usage/runtimes),
run OpenHands in a scriptable [headless mode](https://docs.all-hands.dev/modules/usage/how-to/headless-mode),
or interact with it via a [friendly CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode).
interact with it via a [friendly CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode),
or run it on tagged issues with [a github action](https://github.com/All-Hands-AI/OpenHands-resolver).
Visit [Installation](https://docs.all-hands.dev/modules/usage/installation) for more information and setup instructions.
@@ -92,7 +93,7 @@ For details, please check [CONTRIBUTING.md](./CONTRIBUTING.md).
Whether you're a developer, a researcher, or simply enthusiastic about OpenHands, we'd love to have you in our community.
Let's make software engineering better together!
- [Slack workspace](https://join.slack.com/t/opendevin/shared_invite/zt-2oikve2hu-UDxHeo8nsE69y6T7yFX_BA) - Here we talk about research, architecture, and future development.
- [Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2tom0er4l-JeNUGHt_AxpEfIBstbLPiw) - Here we talk about research, architecture, and future development.
- [Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
## 📈 Progress
+2 -2
View File
@@ -50,6 +50,7 @@ LLM_API_KEY="sk_test_12345"
```bash
docker run -it \
--pull=always \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
-e SANDBOX_USER_ID=$(id -u) \
-e WORKSPACE_MOUNT_PATH=$WORKSPACE_BASE \
-e LLM_API_KEY=$LLM_API_KEY \
@@ -58,7 +59,7 @@ docker run -it \
-v /var/run/docker.sock:/var/run/docker.sock \
--add-host host.docker.internal:host-gateway \
--name openhands-app-$(date +%Y%m%d%H%M%S) \
docker.all-hands.dev/all-hands-ai/openhands:0.11 \
docker.all-hands.dev/all-hands-ai/openhands:0.12 \
python -m openhands.core.cli
```
@@ -107,4 +108,3 @@ Expected Output:
```bash
🤖 An error occurred. Please try again.
```
+2 -2
View File
@@ -44,6 +44,7 @@ LLM_API_KEY="sk_test_12345"
```bash
docker run -it \
--pull=always \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
-e SANDBOX_USER_ID=$(id -u) \
-e WORKSPACE_MOUNT_PATH=$WORKSPACE_BASE \
-e LLM_API_KEY=$LLM_API_KEY \
@@ -52,7 +53,6 @@ docker run -it \
-v /var/run/docker.sock:/var/run/docker.sock \
--add-host host.docker.internal:host-gateway \
--name openhands-app-$(date +%Y%m%d%H%M%S) \
docker.all-hands.dev/all-hands-ai/openhands:0.11 \
docker.all-hands.dev/all-hands-ai/openhands:0.12 \
python -m openhands.core.main -t "write a bash script that prints hi"
```
+3 -3
View File
@@ -11,15 +11,15 @@
The easiest way to run OpenHands is in Docker.
```bash
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik
docker run -it --rm --pull=always \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.11-nikolaik \
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.12-nikolaik \
-v /var/run/docker.sock:/var/run/docker.sock \
-p 3000:3000 \
--add-host host.docker.internal:host-gateway \
--name openhands-app \
docker.all-hands.dev/all-hands-ai/openhands:0.11
docker.all-hands.dev/all-hands-ai/openhands:0.12
```
You can also run OpenHands in a scriptable [headless mode](https://docs.all-hands.dev/modules/usage/how-to/headless-mode), as an [interactive CLI](https://docs.all-hands.dev/modules/usage/how-to/cli-mode), or using the [OpenHands GitHub Action](https://docs.all-hands.dev/modules/usage/how-to/github-action).
+37
View File
@@ -0,0 +1,37 @@
# DiscoveryBench with OpenHands
[DiscoveryBench](https://github.com/allenai/discoverybench/) [(Paper)](https://arxiv.org/abs/2407.01725v1) contains 264 tasks collected across 6 diverse domains, such as biology, economics, and sociology. It incorporates discovery workflows from published papers to approximate the real-world challenges faced by researchers.
<p align="center">
<a href="[https://github.com/allenai/discoverybench](https://github.com/allenai/discoverybench)">
<img src="https://raw.githubusercontent.com/allenai/discoverybench/refs/heads/main/assets/discoverybench-openhands-teaser.png" width="100%" alt="DiscoveryBench Background" />
</a>
</p>
## Setup Environment and LLM Configuration
1. Please follow instructions mentioned [here](https://github.com/openlocus/OpenHands/blob/discoverybench-openhands-integration/evaluation/README.md#setup) to setup OpenHands development environment and LLMs locally
2. Execute the bash script to start DiscoveryBench Evaluation
```
./evaluation/discoverybench/scripts/run_infer.sh [YOUR MODEL CONFIG]
```
Replace `[YOUR MODEL CONFIG]` with any model the model that you have set up in `config.toml`
## Run Inference on DiscoveryBench Instances
When the `run_infer.sh` script is started, it will automatically pull the latest DiscoveryBench instances & set up the agent environment. The OpenHands agent is invoked to process the task within this environment, producing a hypothesis. We then evaluate it against the “gold” hypothesis provided by DiscoveryBench. The evaluation result, along with the agent chat history is logged to `output.jsonl` under `evaluation_outputs`.
```
./evaluation/discoverybench/scripts/run_infer.sh [MODEL_CONFIG] [GIT_COMMIT] [AGENT] [EVAL_LIMIT] [NUM_WORKERS]
```
- `MODEL_CONFIG`: Name of the model you want to evaluate with
- `GIT_COMMIT`: This should be the git commit hash or release tag for OpenHands, e.g., HEAD or a specific tag like 0.6.2.
- `AGENT`: Use CoderActAgent, right now it only supports that.
- `EVAL_LIMIT`: Number of samples to evaluate.
- `NUM_WORKERS`: Number of workers to parallelize the evaluation process.
@@ -0,0 +1,7 @@
## DiscoveryBench Evaluation Utils
- **`eval_w_subhypo_gen.py`**: Implements the DiscoveryBench logic for evaluating agent-generated hypotheses.
- **`lm_utils.py`**: Provides utility functions necessary for the evaluation process.
- **`openai_helpers.py`**: Includes helper functions for OpenAI-related tasks.
- **`openai_semantic_gen_prompts.py`**: Contains prompts used for semantic generation.
- **`response_parser.py`**: Handles the parsing of agent-generated hypotheses.
@@ -0,0 +1,538 @@
import json
import logging
from openai import OpenAI
from .lm_utils import run_chatgpt_query_multi_turn
from .openai_helpers import get_response
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def get_score_from_answer(type, answer):
if type == 'context':
answer = answer.replace('Answer:', '').strip()
if answer.startswith('A)'):
return 1.0
elif answer.startswith('B)'):
return 0.0
return -1.0
elif type == 'var':
try:
var_json = json.loads(answer)
# print(f"var_json:{var_json}")
p = 0.0
r = 0.0
f1 = 0.0
if var_json['sizeB']:
p = var_json['intersection'] / var_json['sizeB']
if var_json['sizeA']:
r = var_json['intersection'] / var_json['sizeA']
if p > 0.0 and r > 0.0:
f1 = (2 * p * r) / (p + r)
else:
f1 = 0.0
eval_rec = {
'p': p,
'r': r,
'f1': f1,
'sizeA': var_json['sizeA'],
'sizeB': var_json['sizeB'],
'intersection': var_json['intersection'],
'explanation': var_json['explanation'],
}
print(f'var_eval: {eval_rec}')
return eval_rec
except Exception: # COMMENT: added Exception
return {'p': -1.0, 'r': -1.0, 'f1': -1.0}
elif type == 'rel':
print(answer)
rel_json = json.loads(answer)
answer_str = rel_json['answer'].strip()
if answer_str.startswith('A') or 'very similar' in answer_str:
return 1.0
elif (
answer_str.startswith('B') or 'similar but general than HypoA' in answer_str
):
return 0.5
elif answer_str.startswith('C') or 'different' in answer_str:
return 0.0
return -1.0
return -1.0
def ask_dimension_question(
query,
gold_hypo,
gold_workflow,
gen_hypo,
gen_workflow,
dataset_meta,
llm_used,
dimension,
dataset_type,
use_column_metadata=True,
):
dimension_question = ''
answer = ''
score = 0.0
if dimension == 'var':
score = {'p': -1.0, 'r': -1.0, 'f1': -1.0}
num_tokens = 256
num_retries = 1
json_response = False
messages = [
{
'role': 'system',
'content': 'You are an AI assistant that helps evaluate a data-driven hypothesis. You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
]
if dimension == 'context':
dimension_question = """\
Question: Is HypoB defined in the same context as HypoA?
(Context refers to assumptions/stratification under which the hypotheses are defined.)
Options: A) same B) different
What is your answer?"""
elif dimension == 'var':
dimension_question = """\
Question: For both HypoA and HypoB, what are the different variables found in the hypotheses? \
Return your answer as a JSON object in the following format:
```json
{{
"sizeA": num of variables used in HypoA
"sizeB": num of variables used in HypoB
"intersection": num of variables common in HypoA and HypoB. Use *fuzzy matching* to determine intersection, accounting for paraphrases or slightly different surface forms
"explanation": a short text explanation about the variables
}}```
Answer:"""
num_tokens = 512
num_retries = 1
json_response = True
elif dimension == 'rel':
dimension_question = """\
Question: Does HypoB exhibit the same relation as HypoA?
Compare using following example hierarchy of relationships (based on specificity): \
"there exists a relationship" > "positive relationship" > "positive AND (linear OR quadratic)" > "positive AND linear".
Options: A) very similar B) similar but general than HypoA C) different
Return your answer as a JSON object in the following format:
```json
{{
"answer": one of the options from A) very similar B) similar but general than HypoA C) different
"explanation": a short text explanation about the relationship comparison
}}```
Answer:"""
num_tokens = 512
num_retries = 1
json_response = True
datasets_json = prepare_dataset_metadata_json(
dataset_meta, dataset_type=dataset_type, use_column_metadata=use_column_metadata
)
dimension_question_str = f"""\
You are going to compare two natural-language hypotheses HypoA and HypoB accompanied with optional workflows: WorkflowA for HypoA and WorkflowB for HypoB. \
Both the hypotheses answer the natural language query "QUERY" over the dataset(s) described by dataset description(s) and column description(s) below. \
Compare HypoA and HypoB in terms of three aspects: Contexts, Variables, and Relations. \
E.g., for the hypothesis "From 1995 to 2009, the number of sandhill cranes around the tundra (Indigilka River) surged by an astounding ~10X":
* Contexts refer to stratification of the data under which the given hypothesis is True. E.g., "For all women", "From 1995 to 2009".
* Variables refer to the set of variables (either dependent or independent) that are mentioned in the hypothesis. E.g., number of sandhill cranes, location.
* Relations refer to the form of relation between the variables. E.g., "surged by ~10x".
Answer following questions for a given pair of hypotheses, HypoA and HypoB, along with an explanation grounded on the QUERY and the DATASET(S).
Here is the metadata for the task:
```json
{{
"datasets": {datasets_json},
"query": {query},
"HypoA": {gold_hypo},
"WorkflowA": {gold_workflow},
"HypoB": {gen_hypo},
"WorkflowB": {gen_workflow}
}}
```
{dimension_question}"""
messages.append({'role': 'user', 'content': dimension_question_str})
for retry in range(num_retries):
response = run_chatgpt_query_multi_turn(
messages=messages,
model_name=llm_used,
max_tokens=num_tokens,
temperature=0, # 0 for greedy best decoding
json_response=json_response,
)
if response is not None: # COMMENT: changed from != to is not
break
if response is not None: # COMMENT: changed from != to is not
answer = response.choices[0].message.content.strip()
score = get_score_from_answer(type=dimension, answer=answer)
return dimension_question, answer, score
def prepare_dataset_metadata_json(dataset_meta, dataset_type, use_column_metadata=True):
if dataset_meta is None: # COMMENT: changed from == to is None
return [
{
'dataset_description': '',
'columns': [],
}
]
datasets_json = []
if dataset_type == 'real':
for d in dataset_meta['datasets']:
datasets_json.append(
{
'dataset_description': d['description'],
'columns': [
{'name': col['name'], 'description': col['description']}
for col in d['columns']['raw']
]
if use_column_metadata
else [],
}
)
else:
for d in dataset_meta['datasets']:
datasets_json.append(
{
'dataset_description': d['description'],
'columns': [
{'name': col['name'], 'description': col['description']}
for col in d['columns']
]
if use_column_metadata
else [],
}
)
return datasets_json
def get_sub_hypotheses(
query,
hypo,
workflow,
dataset_meta,
llm_used,
dataset_type,
use_column_metadata=True,
):
client = OpenAI()
extraction_prompt = """\
Given a set of dataset columns, a ground-truth hypothesis, and the analysis workflow used, your task is to extract three dimensions that define the hypothesis: Context, Variables, and Relations. \
Here are the definitions for these dimensions:
- Contexts: Boundary conditions that limit the scope of a hypothesis. E.g., “for men over \
the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then extract the context from the dataset_descrption.
- Variables: Known concepts that interact in a meaningful way under a given context to \
produce the hypothesis. E.g., gender, age, income, or "None" if there is no interacting variable.
- Relations: Interactions between a given set of variables under a given context to produce \
the hypothesis. E.g., “quadratic relationship”, “inversely proportional”, piecewise conditionals, \
or "None" if there is no interacting relationship.
Make sure to only use the information present in the hypothesis and the workflow. Do not add any new information. \
For each dimension, be specific, and do not omit any important details.
Here is the metadata for the task:
```json
{
"datasets": %s,
"hypothesis": "%s",
"workflow": "%s"
}
```
Return your answer as a JSON object in the following format:
```json
{
"sub_hypo": [
{
"text": the hypothesis in natural language,
"context": a short text description of the context of the hypothesis,
"variables": a list of columns involved in the hypothesis,
"relations": a short text description of the relationship between the variables of the hypothesis
},
...
]
}```
"""
datasets_json = prepare_dataset_metadata_json(
dataset_meta, dataset_type, use_column_metadata=use_column_metadata
)
_prompt = extraction_prompt % (datasets_json, hypo, workflow)
sub_hypo_json = get_response(client, _prompt, model=llm_used, max_retry=1)
if sub_hypo_json is not None: # COMMENT: changed from != to is not
# print(f"full hypothesis: {hypo}")
print(f'sub_hypo_json: {sub_hypo_json}')
else:
sub_hypo_json = {
'sub_hypo': [],
}
sub_hypo_json['full_hypo'] = hypo
return sub_hypo_json
def match_context_with_gpt(
gold_hyp, gold_context, pred_hyp, pred_context, model='gpt-3.5-turbo'
):
prompt = f"""\
Given a gold hypothesis, a gold context, a predicted hypothesis, and a predicted context, your task is \
to determine if the predicted context semantically matches the ground-truth context. \
Here is the definition for Context: Boundary conditions that limit the scope of a sub-hypothesis. E.g., “for men over the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then the context is derived from the dataset_descrption. \
Here is the definition for Context: Boundary conditions that limit the scope of a sub-hypothesis. E.g., “for men over the age of 30”, “in Asia and Europe”. If the context applies to the full dataset, then the context is derived from the dataset_descrption. \
If the predicted context matches the gold context, return true, otherwise return false.
If both gold and predicted hypotheses are defined over the context of the full dataset, then also return true.
If both gold and predicted hypotheses are defined over the context of the full dataset, then also return true.
Here is the metadata for the task:
```json
{{
"gold_hypothesis": "{gold_hyp}",
"gold_context": "{gold_context}",
"predicted_hypothesis": "{pred_hyp}",
"predicted_context": "{pred_context}"
}}
```
Return your answer as a JSON object in the following format:
```json
{{
"match": true or false
}}
```"""
client = OpenAI()
output = get_response(client, prompt, model=model)
return output.get('match', False)
def is_matching_context(gold_hyp, gold_context, pred_hyp, pred_context, llm_used):
if gold_context == pred_context:
return True
if 'None' in [gold_context, pred_context]:
return False
return match_context_with_gpt(
gold_hyp, gold_context, pred_hyp, pred_context, model=llm_used
)
def run_eval_gold_vs_gen_NL_subhypo(
query,
gold_hypo,
gold_workflow,
gen_hypo,
gen_workflow,
dataset_meta,
llm_used,
context_score,
dataset_type,
use_column_metadata=True,
):
# GPT-4 based evaluation to evaluate generated hypothesis in terms of context, variables, relation
eval_rec = {
'query': query,
'HypoA': gold_hypo,
'WorkflowA': gold_workflow,
'HypoB': gen_hypo,
'WorkflowB': gen_workflow,
}
for dimension in ['var', 'rel']:
question, answer, score = ask_dimension_question(
query,
gold_hypo,
gold_workflow,
gen_hypo,
gen_workflow,
dataset_meta,
llm_used,
dimension=dimension,
dataset_type=dataset_type,
use_column_metadata=use_column_metadata,
)
eval_rec[dimension] = {'question': question, 'answer': answer, 'score': score}
eval_rec['context'] = context_score
eval_rec['accuracy_score'] = (
1.0
* eval_rec['context']['score']
* eval_rec['var']['score']['f1']
* eval_rec['rel']['score']
)
return eval_rec
def run_eval_gold_vs_gen_NL_hypo_workflow(
query,
gold_hypo,
gold_workflow,
gen_hypo,
gen_workflow,
dataset_meta,
llm_used,
dataset_type,
use_column_metadata=True,
):
# Input: Dataset Metadata, Query, Gold {Hg, Wg}, Predicted {Hp, Wp}
# Output: eval_rec json includes final_score
# Procedure:
# Dataset Metadata, Query, Gold {Hg, Wg}, Pred {Hg, Wg}
# Gold: [Hg1, Hg2] (compute on the fly) Hg1 is a NL form of subhypothesis
# Predicted: [Hp1, Hp2] (compute on the fly)
# Compute Intersection: [(Hg_i, Hp_j), …] # tuples of (gold,pred) that matched with context (do this w/o explicit extraction)
# # filter so that a gold context and a predicted context are only attached to one tuple
# Compute recall_context (programmatically)
# r_v_list = []
# For (Hg_i, Hp_j) in the intersection:
# With Hg_i, Hp_j in NL, ask GPT4 → #variables and #intersection and a paragraph explanation and programmatically calculate f1_v
# Hg_i, Hp_j in NL, ask GPT4 → matching score (0, 0.5 or 1) : A) very similar B) similar but general than HypoA C) different + explanation
# r_v_list ← f1_v * score_r
# accuracy_score = mean(r_v_list)
# score = [ recall_context * mean over predicted context(context_score * var_score *rel_score )]
# recall_context = 1.0 # COMMENT: never used
eval_rec = {
'query': query,
'HypoA': gold_hypo,
'WorkflowA': gold_workflow,
'HypoB': gen_hypo,
'WorkflowB': gen_workflow,
}
gold_sub_hypo_json = get_sub_hypotheses(
query=query,
hypo=gold_hypo,
workflow=gold_workflow,
dataset_meta=dataset_meta,
llm_used=llm_used,
dataset_type=dataset_type,
use_column_metadata=use_column_metadata,
)
if len(gold_sub_hypo_json['sub_hypo']) == 0:
gold_sub_hypo_json['sub_hypo'] = [
{
'text': gold_hypo,
'context': 'None',
'variables': [],
'relations': '',
'explanation': 'unable to segment',
}
]
print(f'gold_sub_hypo_json: {gold_sub_hypo_json}')
gen_sub_hypo_json = get_sub_hypotheses(
query=query,
hypo=gen_hypo,
workflow=gen_workflow,
dataset_meta=dataset_meta,
llm_used=llm_used,
dataset_type=dataset_type,
use_column_metadata=use_column_metadata,
)
if len(gen_sub_hypo_json['sub_hypo']) == 0:
gen_sub_hypo_json['sub_hypo'] = [
{
'text': gen_hypo,
'context': 'None',
'variables': [],
'relations': '',
'explanation': 'unable to segment',
}
]
print(f'gen_sub_hypo_json: {gen_sub_hypo_json}')
eval_rec['gold_sub_hypo'] = gold_sub_hypo_json
eval_rec['gen_sub_hypo'] = gen_sub_hypo_json
gold_subh_covered = []
gen_subh_to_gold_subh = dict()
gen_gold_subh_to_context = dict()
for p_id, gen_subh in enumerate(gen_sub_hypo_json['sub_hypo']):
gen_subh_to_gold_subh[p_id] = -1
for g_id, gold_subh in enumerate(gold_sub_hypo_json['sub_hypo']):
if g_id in gold_subh_covered:
continue
# match context
context_bool = is_matching_context(
gold_subh['text'],
gold_subh.get('context', ''),
gen_subh['text'],
gen_subh.get('context', ''),
llm_used,
)
if context_bool:
context_score = 1.0
else:
context_score = 0.0
if context_score == 1.0: # match only when context_score = 1.0
gen_subh_to_gold_subh[p_id] = g_id
gold_subh_covered.append(g_id)
gen_gold_subh_to_context[f'P{p_id}||G{g_id}'] = {
'question': f"""Comapring: GoldH: {gold_subh["text"]}, GoldC: {gold_subh['context']}\nGenH: {gen_subh['text']}, GenC: {gen_subh['context']}""",
'answer': context_bool,
'score': context_score,
}
break
print(f'gen_subh_to_gold_subh: {gen_subh_to_gold_subh}')
eval_rec['gen_subh_to_gold_subh'] = gen_subh_to_gold_subh
eval_rec['gold_subh_covered'] = gold_subh_covered
matched_gold_gen_subh_evals = dict()
sum_accuracy_score = 0.0
for p_id, g_id in gen_subh_to_gold_subh.items():
if g_id >= 0:
key = f'P{p_id}||G{g_id}'
context_score = gen_gold_subh_to_context[key]
subh_eval_rec = run_eval_gold_vs_gen_NL_subhypo(
query,
gold_hypo,
gold_workflow,
gen_hypo,
gen_workflow,
dataset_meta,
llm_used,
context_score,
dataset_type=dataset_type,
use_column_metadata=use_column_metadata,
)
sum_accuracy_score += subh_eval_rec['accuracy_score']
matched_gold_gen_subh_evals[key] = subh_eval_rec
eval_rec['matched_gold_gen_subh_evals'] = matched_gold_gen_subh_evals
eval_rec['recall_context'] = (
len(gold_subh_covered) / len(gold_sub_hypo_json['sub_hypo'])
if len(gold_sub_hypo_json['sub_hypo'])
else 0.0
)
mean_accuracy_score = (
sum_accuracy_score / len(gen_subh_to_gold_subh)
if len(gen_subh_to_gold_subh)
else 0.0
)
eval_rec['mean_accuracy_score'] = mean_accuracy_score
final_score = eval_rec['recall_context'] * mean_accuracy_score
eval_rec['final_score'] = final_score
print(f'eval_rec: {json.dumps(eval_rec, indent=2)}')
return eval_rec
@@ -0,0 +1,64 @@
import os
import sys
import time
from openai import OpenAI
from tenacity import (
retry,
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
)
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
Model = Literal['gpt-4', 'gpt-3.5-turbo', 'text-davinci-003']
OpenAI.api_key = os.getenv('OPENAI_API_KEY')
OPENAI_GEN_HYP = {
'temperature': 0,
'max_tokens': 250,
'top_p': 1.0,
'frequency_penalty': 0,
'presence_penalty': 0,
}
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def run_chatgpt_query_multi_turn(
messages,
model_name='gpt-4-turbo', # pass "gpt4" for more recent model output
max_tokens=256,
temperature=0.0,
json_response=False,
):
response = None
num_retries = 3
retry = 0
while retry < num_retries:
retry += 1
try:
client = OpenAI()
if json_response:
response = client.chat.completions.create(
model=model_name,
response_format={'type': 'json_object'},
messages=messages,
**OPENAI_GEN_HYP,
)
else:
response = client.chat.completions.create(
model=model_name, messages=messages, **OPENAI_GEN_HYP
)
break
except Exception as e:
print(e)
print('GPT error. Retrying in 2 seconds...')
time.sleep(2)
return response
@@ -0,0 +1,190 @@
import json
def OPENAI_TOPIC_GEN_MESSAGES(n=10):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{
'role': 'user',
'content': f'Given `n`, come up with a list of `n` distinct topics and their descriptions. The topics can be absolutely anything. Be as creative as possible. Return your answer as a JSON object. \n\nFor example, for `n`=3, a valid answer might be:\n```json\n{{"topics": [\n {{"id": 1, "topic": "cooking", "description": "Related to recipes, ingredients, chefs, etc."}},\n {{"id": 2, "topic": "sports", "description": "Related to players, stadiums, trophies, etc."}},\n {{"id": 3, "topic": "antiquing", "description": "Related to unique items, history, etc."}}\n]}}```\n\nNow, give me a list for `n`={n}. Remember, pick diverse topics from everything possible. No consecutive topics should be broadly similar. Directly respond with the answer JSON object.',
},
]
OPENAI_GEN_HYP = {
'temperature': 1.0,
'max_tokens': 4096,
'top_p': 1.0,
'frequency_penalty': 0,
'presence_penalty': 0,
}
def OPENAI_SEMANTICS_GEN_MESSAGES(dependent, relationship, domain, domain_desc):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{
'role': 'user',
'content': f'Given the true relationship in a dataset and a given domain, your task is to come up with an interpretation of some real-world concepts that the relationship could be modeling from the provided domain. It\'s okay to be wrong, but suggest something reasonable. Try as much as possible to make sure that the TARGET is actually derivable from the other variables. Give your answer as a JSON object. Here\'s an example:\n\nRelationship for x2 = "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)"\nDomain="Sales"\nDomain description="Related to product distribution, revenues, marketing, etc."\n\nBased on this, the following real-world concepts might be applicable:\n```json\n{{\n "dependent": "x2",\n "relationship": "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)",\n "domain": "Sales",\n "trends": {{\n "x1": "Positive, cubic factor",\n "x2": "TARGET",\n "x3": "Positive, linear factor",\n "x4": "No relation",\n "x5": "Positive quadratic factor",\n "x6": "Positive, inverse quadratic factor"\n }},\n "interpretation": {{\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area", "is_target": true}},\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n}}```\n\nHere\'s a new test question:\nRelationship for {dependent} = "{relationship}"\nDomain = "{domain}"\nDomain description="{domain_desc}"\n\nRespond only with the answer JSON. Make sure that you do not forget to include the TARGET variable in the interpretation object.',
},
]
def OPENAI_SEMANTICS_GEN_W_MAP_MESSAGES(
dependent, relationship, domain, domain_desc, mapping
):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{
'role': 'user',
'content': f'Given a partial mapping from variables to real-world concepts and a true relationship in a dataset, your task is to come up with an interpretation of real-world concepts for the variables without any assigned mapping (those starting with x). Suggest something reasonable. The dependent variable must be derivable only from the other variables in the dependent relationship. Give your answer as a JSON object. Here\'s an example:\n\nExample partial mapping and relationship:\n```json\n{{\n "domain": "Sales",\n "domain_description": "Related to product distribution, revenues, marketing, etc.",\n "variable_mapping": {{\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n }},\n "dependent_variable": "sales_area",\n "dependent_relationship": "(96.4 * pop_area ** 3) + (88.72 * x5 ** 2) + (81.96 * dist_to_distr_ctr ** -2) + (28.13 * x3) + (97.0)"\n}}```\nBased on this, an example answer would be:\n```json\n{{\n "dependent_variable": "sales_area",\n "missing_mapping": ["x3", "x5"],\n "trends": {{\n "x3": "Positive, linear factor",\n "x5": "Positive quadratic factor"\n }},\n "interpretation": {{\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }}\n}}```\n\nHere\'s a new test question:\n```json\n{{\n "domain": "{domain}",\n "domain_description": "{domain_desc}",\n "variable_mapping": {json.dumps(mapping, indent=2)},\n "dependent_variable": "{dependent}",\n "dependent_relationship": "{relationship}"\n}}```\nRespond only with the answer JSON.',
},
]
def OPENAI_SEMANTICS_GEN_SUMMARY_MESSAGES(dataset):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{
'role': 'user',
'content': f'Given the following descriptions of the columns of a dataset, your task is to come up with a natural language overview of the dataset, which should include (1) what the dataset is about, (2) how the data was collected, (3) when the data was collected, and (3) for what purpose the data was collected. Be specific and creative.\n\nExample dataset:\n```json\n{{ \n "dataset": {{ \n "x6": {{"description": "Ancient artifact significance score", "name": "artifact_significance_score", "is_target": true}},\n "x1": {{"description": "Distance to ancient city center", "name": "dist_to_ancient_city_ctr"}},\n "x2": {{"description": "Quantity of discovered relics", "name": "relic_discovery_qty"}},\n "x3": {{"description": "Years since last archaeological expedition", "name": "years_since_exp"}},\n "x4": {{"description": "Number of artifacts in excavation site", "name": "artifact_qty"}},\n "x5": {{"description": "Soil fertility coefficient", "name": "soil_fertility_coef"}},\n "x7": {{"description": "Distance to ancient burial grounds", "name": "dist_to_burial_grounds"}},\n "x8": {{"description": "Population estimate of ancient civilization", "name": "ancient_civilization_pop_estimate"}},\n "x9": {{"description": "Temperature variation in excavation region", "name": "temp_variation"}}\n }}\n}}```\nExample description:\nThis dataset is about archaeological explorations and findings linked to ancient civilizations. The data was collected in the form of field metrics during various archaeological expeditions during the late mid-20th century. The purpose of the data collection is to evaluate the significance of ancient artifacts discovered during excavations.\n\nHere is a new test dataset.\n{json.dumps(dataset, indent=2)}\nProvide only the description.',
},
]
def OPENAI_GEN_HYPO_MESSAGES(dataset):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{
'role': 'user',
'content': f'Given a dataset with its descriptions and the true functional relationship between its variables, your task is to generate 3 levels of hypotheses for the stated relationship in plain English. The three levels are "broad", "medium" and "narrow". Make sure that the hypotheses sound natural. *Only include concepts for variables that are present in the provided functional relationship.* Give your answer as a JSON.\n\nFor example, an example dataset might be the following:\n```json\n{{\n "domain": "cybersecurity",\n "summary": "This dataset is about measuring cybersecurity threats in a system. The data was collected by monitoring various cybersecurity metrics in a network environment. The purpose of the data collection is to assess and predict potential cybersecurity risks and vulnerabilities.",\n "variables": [\n {{\n "description": "Level of cybersecurity threat",\n "name": "cybersecurity_threat",\n "is_target": true\n }},\n {{\n "description": "Number of failed login attempts",\n "name": "failed_login_attempts"\n }},\n {{\n "description": "Amount of encrypted data",\n "name": "encrypted_data"\n }},\n {{\n "description": "Frequency of software updates",\n "name": "software_updates"\n }},\n {{\n "description": "Number of antivirus software installed",\n "name": "antivirus_software"\n }},\n {{\n "description": "Quality of firewall protection",\n "name": "firewall_quality"\n }}\n ],\n "relationship": {{\n "dependent": "cybersecurity_threat",\n "relation": "-53.5*encrypted_data**2 - 53.85*failed_login_attempts**2 + 67.75*firewall_quality - 92.16 - 36.68/software_updates**3"\n }}\n}}```\nGiven this dataset, the following is a valid answer:\n```json\n{{\n "broad": {{\n "instruction": "Be vague. Only indicate which concepts might be related but not how they are related",\n "hypothesis": "Threat to cybersecurity is influenced by several factors including the amount of encrypted data, the number of failed login attempts, the quality of the firewall, as well as how often the software is updated."\n }},\n "medium": {{\n "instruction": "Be slightly more specific. For each factor, indicate carefully whether it positively or negatively affects the relationship, but do not indicate what the exponent is.",\n "hypothesis": "Cybersecurity threat tends to decrease with the amount of data encryption, the number of failed login attempts, as well as the frequency of software updates to some extent, while improvement in the firewall quality has a positive effect."\n }},\n "narrow": {{\n "instruction": "Be specific. Communicate the concepts, whether there is a positive or negative effect (be careful), and the meaning of the exponent",\n "hypothesis": "The threat to cybersecurity interacts in a complex manner with various factors. As the amount of encrypted data increases, there is a quadratic decrease in threat. Similarly for the number of failed login attempts, there is a negative quadratic relationship. The quality of the firewall protection on the other hand demonstrates a positive and linear relationship. Finally, the frequency of software updates has an inverse cubic relationship to the threat."\n }},\n}}\n```\n\nBased on this, provide an answer for the following test dataset:\n```json\n{dataset}```\nRespond only with a JSON.',
},
]
def create_prompt(usr_msg):
return [
{
'role': 'system',
'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.',
},
{'role': 'user', 'content': usr_msg},
]
def get_response(client, prompt, max_retry=5, model='gpt-3.5-turbo', verbose=False):
n_try = 0
while n_try < max_retry:
response = client.chat.completions.create(
model=model, messages=create_prompt(prompt), **OPENAI_GEN_HYP
)
# COMMENT: changed from
# response.choices[0].message.content.strip().strip('```json').strip('```')
content = response.choices[0].message.content
cleaned_content = content.split('```json')[1].split('```')[0].strip()
output = cleaned_content
try:
response_json = json.loads(output)
return response_json
except ValueError:
if verbose:
print(f'Bad JSON output:\n\n{output}')
n_try += 1
if n_try < max_retry:
if verbose:
print('Retrying...')
else:
if verbose:
print('Retry limit reached')
return None
def get_code_fix(
client, code, error, max_retry=5, model='gpt-3.5-turbo', verbose=False
):
prompt = f"""\
Given the following code snippet and error message, provide a single-line fix for the error. \
Note that the code is going to be executed using python `eval`. \
The code should be executable and should not produce the error message. Be as specific as possible.
Here's the code and the error:
{{
"code": "{code}",
"error": "{error}"
}}
Return only a JSON object with the fixed code in the following format:
```json
{{
"fixed_code": "..."
}}"""
response = get_response(
client, prompt, max_retry=max_retry, model=model, verbose=verbose
)
return response
def get_new_hypothesis(
client, target, old, expr, cols, model='gpt-3.5-turbo', verbose=False
):
prompt = f"""\
Given a target column from a dataset, a pandas expression to derive the column from existing columns, a list of \
existing columns, and a previously written hypothesis text, carefully check if the hypothesis text is consistent with \
the pandas expression or not. If it is consistent, simply return the hypothesis as it is. If it is not consistent, \
provide a new natural language hypothesis that is consistent with the pandas expression using only the provided \
information. Be specific.
Here's the information:
```json
{{
"target_column": "{target}",
"pandas_expression": "{expr}",
"existing_columns": {json.dumps(cols, indent=4)}
"old_hypothesis": "{old}",
}}```
Give your answer as a new JSON with the following format:
```json
{{
"hypothesis": "..."
}}"""
response = get_response(client, prompt, model=model, verbose=verbose)
return response
def replace_variable(client, expr, old, new, model='gpt-3.5-turbo', verbose=False):
prompt = f"""\
Given a pandas "expression", replace mentions of the "old" column with its "new" value such that the resultant \
expression is equivalent to the original expression.
Here's the information:
```json
{{
"expression": "{expr}",
"old": "{old}",
"new": "{new}"
}}```
Give your answer as a new JSON with the following format:
```json
{{
"new_expression": "..."
}}"""
response = get_response(client, prompt, model=model, verbose=verbose)
return response
@@ -0,0 +1,151 @@
common_hypothesis_features = [
'1-2 sentences',
'surprising finding',
'includes numeric concepts',
'includes categorical concepts',
'includes binary concepts',
]
hypothesis_features = [
['requires within-cluster analysis'],
['requires across-cluster analysis'],
['corresponds to a polynomial relationship of some columns'],
['corresponds to a ratio between some columns'],
['requires temporal analysis'],
['relationship is based on descriptive statistics of some columns'],
['requires concepts based on percentage or percentiles'],
['relationship is only applicable to one cluster in the data and not the others'],
]
column_features = [
[
'must have one target column',
'must have quantifiable columns',
'must have a few categorical columns',
'make sure the categorical column values do not contain special characters',
'include a few distractor columns',
]
]
common_pandas_features = [
'must be executable using python `eval` to create the target column in variable `df` (pandas dataframe)',
"for e.g., df['A']**2 + 3*df['B'] + 9, np.where(df['A'] > 3, 'Yes', 'No'), etc.",
'variables in pandas_expression must be from the existing columns listed above',
'variables in pandas_expression must NOT contain the target column itself',
]
pandas_features = [
['expression is a quadratic polynomial'],
['expression is a cubic polynomial'],
['expression is a ratio of existing columns'],
['expression is derived through logical combination of existing columns'],
# workflow
]
pandas_features = [common_pandas_features + p for p in pandas_features]
common_derived_features = [
'1-2 sentences',
'includes numeric concepts',
'includes categorical concepts',
'includes binary concepts',
]
derived_features = [common_derived_features + h for h in hypothesis_features]
hypothesis_features = [common_hypothesis_features + h for h in hypothesis_features]
PROMPT_HYP = """\
Given a dataset topic and description, generate an interesting hypothesis based on \
the provided instructions. Be creative and come up with an unusual finding.
```json
{
"topic": "%s",
"description": "%s",
"hypothesis_features": %s,
"hypothesis": "..."
}```
Give your answer as a new JSON with the following format:
```json
{
"hypothesis": "..."
}
```"""
PROMPT_COL = """\
Given a dataset topic, its description, and a true hypothesis that can be determined from it, \
generate a list of valid columns based on the provided instructions.
```json
{
"topic": "%s",
"description": "%s",
"hypothesis": "%s",
"column_instructions": %s,
"columns": [
{
"col_name": "...", # should be an "_"-separated string
"description": "...",
"data_type": "...", # should be executable using python's `eval` function. E.g., str, float, int, bool
"data_range": {...}, # should be either {"min": ..., "max": ...} or {"values": [...]}
"is_distractor": true/false, # boolean indicating whether this is a distractor that could cause confusion during data analysis
"is_target": true/false # boolean indicating whether this is the target variable for the hypothesis; at least one column should be the target
},
...
],
"pandas_instructions": %s,
"pandas_equation_for_hypothesis": {
"target_col": "...",
"target_col_type": "...",
"target_col_range": {...},
"independent_cols_in_pandas_expression": [], # list of column names that will be used to derive the target column
"pandas_expression": "..." # expression to derive df[target_col] using df[ind_col1], df[ind_col2], etc.
}
}```
Give your answer as a new JSON with the "columns" and "pandas_equation_for_hypothesis" keys filled using the following format:
```json
{
"columns": [...],
"pandas_equation_for_hypothesis": {...}
}
```"""
PROMPT_DER = """\
Given a dataset topic, description, a true hypothesis that can be determined from the data, \
and a target column from the dataset, generate a hypothesis for the target column using new independent columns not present in the existing columns.
```json
{
"topic": "%s",
"description": "%s",
"hypothesis": "%s",
"existing_columns": %s,
"target_column": "%s",
"new_to_target_instructions": %s,
"new_to_target_hypothesis": "...", # describe a relationship between new columns that explains the target column
"new_columns_for_target": [ # do not repeat any of the existing columns in the dataset
{
"col_name": "...", # should be an "_"-separated string
"description": "...",
"data_type": "...", # should be executable using python's `eval` function. E.g., str, float, int, bool
"data_range": {...}, # should be either {"min": ..., "max": ...} or {"values": [...]}
},
...
],
"pandas_instructions": %s,
"pandas_equation_for_new_to_target_hypothesis": {
"target_col": "...",
"target_col_type": "...",
"target_col_range": {...},
"independent_cols_in_pandas_expression": [], # list of column names from new_columns_for_target that will be used to derive target_col
"pandas_expression": "..." # expression to derive df[target_col] using df[ind_col1], df[ind_col2], etc.
}
}```
Give your answer as a new JSON with the "new_to_target_hypothesis", "new_columns_for_target", and \
"pandas_equation_for_new_to_target_hypothesis" keys filled using the following format:
```json
{
"new_to_target_hypothesis": "...",
"new_columns_for_target": [...],
"pandas_equation_for_new_to_target_hypothesis": {...}
}
```"""
@@ -0,0 +1,52 @@
workflow_summary_markers = [
'WORKFLOW SUMMARY',
'WORKFLOW_SUMMARY',
'WORKFLOW-SUMMARY',
'Workflow Summary',
]
final_answer_markers = [
'FINAL ANSWER',
'FINAL_ANSWER',
'FINAL-ANSWER',
'Final Answer',
'Scientific Hypothesis',
'Hypothesis',
]
next_agent_markers = [
'NEXT AGENT',
'NEXT-AGENT',
'NEXT_AGENT',
'FEEDBACK',
]
def extract_between(content, start_markers, end_markers=None):
for marker in start_markers:
if marker in content:
result = content.split(marker, 1)[1]
if end_markers:
for end_marker in end_markers:
if end_marker in result:
result = result.split(end_marker, 1)[0]
return result
return ''
def extract_gen_hypo_from_logs(content: str):
error = ''
gen_workflow = extract_between(
content, workflow_summary_markers, final_answer_markers
)
if not gen_workflow:
error += 'No Workflow Summary found in the line. | '
gen_hypothesis = extract_between(content, final_answer_markers, next_agent_markers)
if not gen_hypothesis:
error += 'No Final Answer in the line.'
return gen_hypothesis, gen_workflow, error
+491
View File
@@ -0,0 +1,491 @@
import asyncio
import json
import os
import git
import pandas as pd
from evaluation.discoverybench.eval_utils.eval_w_subhypo_gen import (
run_eval_gold_vs_gen_NL_hypo_workflow,
)
from evaluation.discoverybench.eval_utils.response_parser import (
extract_gen_hypo_from_logs,
)
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
AgentConfig,
AppConfig,
SandboxConfig,
get_llm_config_arg,
parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
EVALUATION_LLM = 'gpt-4-1106-preview'
DATA_FILES = {}
LIBRARIES = [
'pandas',
'numpy',
'scipy',
'matplotlib',
'seaborn',
'scikit-learn',
'statsmodels',
]
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
AGENT_CLS_TO_INST_SUFFIX = {
'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
def get_config(
metadata: EvalMetadata,
) -> AppConfig:
config = AppConfig(
default_agent=metadata.agent_class,
run_as_openhands=False,
runtime='eventstream',
max_iterations=metadata.max_iterations,
sandbox=SandboxConfig(
base_container_image='python:3.12-bookworm',
enable_auto_lint=True,
use_host_network=False,
),
# do not mount workspace
workspace_base=None,
workspace_mount_path=None,
)
config.set_llm_config(metadata.llm_config)
agent_config = AgentConfig(
function_calling=False,
codeact_enable_jupyter=True,
codeact_enable_browsing_delegate=True,
)
config.set_agent_config(agent_config)
return config
def get_dv_query_for_real(
datasets, question, domain_knowledge=None, workflow_tags=None
):
"""
Prepare a structured query for the agent to execute on the specified datasets.
This function constructs a query by compiling metadata from the provided datasets, along with any relevant domain knowledge and workflow tags.
Args:
datasets: List of datasets
question: Query to be answered
domain_knowledge: Domain knowledge if any
workflow_tags: Workflow tags if any
Returns:
query_to_dv: Query to be run on the dataset
dataset_meta: Metadata of the dataset
"""
dataset_meta = ''
for dataset_metadata in datasets:
dataset_meta += 'Dataset name: ' + dataset_metadata['name']
dataset_meta += 'Dataset description: ' + dataset_metadata['description']
dataset_meta += '\nBrief description of columns: '
for col in dataset_metadata['columns']['raw']:
dataset_meta += col['name'] + ': ' + col['description'] + ', '
query_to_dv = dataset_meta
query_to_dv += f'\nQuery: {question}'
if domain_knowledge:
query_to_dv += (
'\nAdditionally, we provide some hints that might be useful to solve the task. Domain Knowledge: \n'
+ domain_knowledge
+ '.\n'
)
if workflow_tags:
query_to_dv += 'The meta tags are: ' + workflow_tags + '.\n'
query_to_dv += (
'In the final answer, please write down a scientific hypothesis in '
'natural language, derived from the provided dataset, clearly stating the '
'context of hypothesis (if any), variables chosen (if any) and '
'relationship between those variables (if any) including any statistical significance.'
'Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY:'
)
# Run the NL query through datavoyager
return query_to_dv, dataset_meta
def initialize_runtime(runtime: Runtime, data_files: list[str]):
"""
Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
"""
logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
obs: CmdOutputObservation
action = CmdRunAction(command='mkdir -p /workspace')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
action = CmdRunAction(command='cd /workspace')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
for file in data_files:
runtime.copy_to(
file,
'/workspace',
)
for lib in LIBRARIES:
action = CmdRunAction(command=f'pip install {lib}')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
def get_last_agent_finish_action(state: State) -> AgentFinishAction:
for event in state.history.get_events(reverse=True):
if isinstance(event, AgentFinishAction):
return event
return None
def get_last_message_action(state: State) -> MessageAction:
for event in state.history.get_events(reverse=True):
if isinstance(event, MessageAction):
return event
return None
def complete_runtime(state: State):
last_agent_finish_action = get_last_agent_finish_action(state)
last_agent_message_action = get_last_message_action(state)
if last_agent_finish_action is not None:
final_message_1 = last_agent_finish_action.thought
gen_hypo_1, gen_workflow_1, error_1 = extract_gen_hypo_from_logs(
final_message_1
)
else:
gen_hypo_1, gen_workflow_1, error_1 = '', '', ''
if last_agent_message_action is not None:
final_message_2 = last_agent_message_action.content
gen_hypo_2, gen_workflow_2, error_2 = extract_gen_hypo_from_logs(
final_message_2
)
else:
gen_hypo_2, gen_workflow_2, error_2 = '', '', ''
if gen_hypo_1 and gen_hypo_2:
test_result = {
'gen_hypo': last_agent_finish_action.thought
if last_agent_finish_action
else last_agent_message_action.content,
'gen_workflow': '',
'error': '',
}
return test_result
test_result = {
'gen_hypo': gen_hypo_1 if gen_hypo_1 else gen_hypo_2,
'gen_workflow': gen_workflow_1 if gen_workflow_1 else gen_workflow_2,
'error': error_1 if error_1 else error_2,
}
return test_result
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
):
"""
Process and evaluate a single instance of the dataset.
This function executes the OpenHands agent
for a specific instance of the dataset. It retrieves
the agent's results and evaluates them against the gold
hypothesis.
Args:
instance: A single row of the dataset
metadata: Metadata for the evaluation
reset_logger: Whether to reset the logger
Returns:
output: EvalOutput object
"""
config = get_config(metadata)
# use a session id for concurrent evaluation
sid = 'ID_' + str(instance.instance_id)
# Setup the logger properly, so you can run
# multi-processing to parallelize the evaluation
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
problem_statement, dataset_metadata = get_dv_query_for_real(
datasets=instance.datasets,
question=instance.query,
domain_knowledge=instance.domain_knowledge,
workflow_tags=instance.workflow_tags,
)
# Prepare instruction
instruction = (
f'You are a discovery agent who can execute a python code only once to answer a query based on one or more datasets. The datasets will be present in the current directory.\n\n'
'Environment has been set up for you to start working. You may assume all necessary tools and datasets are installed.\n\n'
'# Problem Statement\n'
f'{problem_statement}\n\n'
)
instruction += (
'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
'You should NOT modify any existing test case files. If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
)
# NOTE: You can actually set slightly different instruction for different agents
instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
# Here's how you can run the agent (similar to the `main` function) and get the final task state
runtime = create_runtime(config, sid=sid)
call_async_from_sync(runtime.connect)
initialize_runtime(runtime, instance.data_files)
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
),
)
)
if state is None:
raise ValueError('State should not be None.')
metrics = state.metrics.get() if state.metrics else None
test_result = complete_runtime(state)
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
# DiscoveryBench Evaluation
eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow(
query=instance.query,
gold_hypo=instance.gold_hypo,
gold_workflow='',
gen_hypo=test_result['gen_hypo'],
gen_workflow='',
dataset_meta=instance.dataset_metadata,
llm_used=EVALUATION_LLM,
dataset_type='real',
)
test_result['eval_rec'] = eval_rec
output = EvalOutput(
instance_id=str(instance.instance_id),
instruction=instruction,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
test_result=test_result,
)
return output
def update_csv_name(name):
name = name.replace('-', '_')
if 'meta_regression' in name:
name = name.replace('meta_regression', 'meta-regression')
if 'ML_enabled' in name:
name = name.replace('ML_enabled', 'ML-enabled')
return name
def list_csv_files(list_of_datasets):
res = []
for ele in list_of_datasets:
for key, value in ele.items():
if key == 'name':
csv_file_name = update_csv_name(value)
res.append(DATA_FILES[csv_file_name])
return res
def create_dataset(repo_location: str, split: str = 'test'):
"""
Create a dataset from the discoverybench repository
by walking through the repository and extracting metadata
from the metadata_{}.json files
Args:
repo_location: Location of the repository
split: Split of the dataset to use
Returns:
df: DataFrame containing the dataset instances
"""
data_dict = {}
data_location = os.path.join(repo_location, 'discoverybench', 'real', split)
answer_key_location = os.path.join(repo_location, 'eval', 'answer_key_real.csv')
idx = 0
for root, dirs, files in os.walk(data_location):
for file in files:
if file.endswith('.json'):
if 'metadata' in file:
metadata = json.load(open(os.path.join(root, file)))
dataset = root.split('/')[-1]
metadata_id = file.split('_')[-1].split('.')[0]
domain = metadata.get('domain', '')
domain_knowledge = metadata.get('domain_knowledge', '')
workflow_tags = metadata.get('workflow_tags', '')
datasets = metadata.get('datasets', [])
queries = metadata.get('queries', [])
gold_workflow = metadata.get('workflow')
# loop through queries list to get queries
# and each query has qid; add that to dictionary
for query in queries[0]:
qid = query.get('qid', '')
data = {
'dataset': dataset,
'metadata_id': metadata_id,
'qid': qid,
'domain': domain,
'domain_knowledge': domain_knowledge,
'workflow_tags': workflow_tags,
'datasets': datasets,
'question_type': query['question_type'],
'query': query['question'],
'gold_workflow': gold_workflow,
'dataset_metadata': metadata,
}
data_dict[idx] = data
idx += 1
if file.endswith('.csv'):
DATA_FILES[file] = os.path.join(root, file)
if file.endswith('.txt'):
DATA_FILES[file] = os.path.join(root, file)
df = pd.DataFrame.from_dict(data_dict, orient='index')
df['instance_id'] = df.index
df['data_files'] = df['datasets'].apply(lambda x: list_csv_files(x))
answer_key = pd.read_csv(answer_key_location)
answer_key = answer_key.rename(
columns={
'metadataid': 'metadata_id',
'query_id': 'qid',
'gold_hypothesis': 'gold_hypothesis',
}
)
df['qid'] = df['qid'].astype(int)
df['metadata_id'] = df['metadata_id'].astype(int)
answer_key['qid'] = answer_key['qid'].astype(int)
answer_key['metadata_id'] = answer_key['metadata_id'].astype(int)
df = pd.merge(df, answer_key, on=['dataset', 'metadata_id', 'qid'], how='left')
return df
if __name__ == '__main__':
args = parse_arguments()
# clone git repositor for csv files
repo_url = 'https://github.com/allenai/discoverybench.git'
repo_location = 'git-discoverybench-allenai'
try:
git.Repo.clone_from(repo_url, repo_location)
except git.exc.GitCommandError:
print('Repository already exists')
dataset = create_dataset(repo_location)
# check if there is any empty csv_file
if dataset['data_files'].isnull().any():
raise ValueError('Some csv files are missing.')
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
metadata = make_metadata(
llm_config,
'discoverybench-python',
args.agent_cls,
args.max_iterations,
args.eval_note,
args.eval_output_dir,
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
run_evaluation(
instances,
metadata,
output_file,
args.eval_num_workers,
process_instance,
)
+46
View File
@@ -0,0 +1,46 @@
#!/bin/bash
set -eo pipefail
source "evaluation/utils/version_control.sh"
MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
NUM_WORKERS=$5
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
echo "Number of workers not specified, use default $NUM_WORKERS"
fi
# ################################################################################
checkout_eval_branch
if [ -z "$AGENT" ]; then
echo "Agent not specified, use default CodeActAgent"
AGENT="CodeActAgent"
fi
get_agent_version
echo "AGENT: $AGENT"
echo "AGENT_VERSION: $AGENT_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
COMMAND="poetry run python evaluation/discoverybench/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--max-iterations 10 \
--max-chars 10000000 \
--eval-num-workers $NUM_WORKERS \
--eval-note $AGENT_VERSION"
if [ -n "$EVAL_LIMIT" ]; then
echo "EVAL_LIMIT: $EVAL_LIMIT"
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
# Run the command
eval $COMMAND
+6 -9
View File
@@ -13,6 +13,7 @@ from evaluation.utils.shared import (
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
@@ -55,18 +56,14 @@ def get_config(
workspace_base=None,
workspace_mount_path=None,
)
if metadata.llm_config.log_completions:
metadata.llm_config.log_completions_folder = os.path.join(
metadata.eval_output_dir, 'llm_completions', instance_id
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config, metadata.eval_output_dir, instance_id
)
logger.info(
f'Logging LLM completions for instance {instance_id} to '
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
)
agent_config = AgentConfig(
codeact_enable_jupyter=True,
codeact_enable_browsing_delegate=True,
codeact_enable_browsing=True,
codeact_enable_llm_editor=False,
)
config.set_agent_config(agent_config)
@@ -0,0 +1,44 @@
from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestResult
from openhands.events.action import AgentFinishAction, MessageAction
from openhands.events.event import Event
from openhands.events.observation import AgentDelegateObservation
from openhands.runtime.base import Runtime
class Test(BaseIntegrationTest):
INSTRUCTION = 'Look at https://github.com/All-Hands-AI/OpenHands/pull/8, and tell me what is happening there and what did @asadm suggest.'
@classmethod
def initialize_runtime(cls, runtime: Runtime) -> None:
pass
@classmethod
def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult:
# check if the "The answer is OpenHands is all you need!" is in any message
message_actions = [
event
for event in histories
if isinstance(
event, (MessageAction, AgentFinishAction, AgentDelegateObservation)
)
]
for event in message_actions:
if isinstance(event, AgentDelegateObservation):
content = event.content
elif isinstance(event, AgentFinishAction):
content = event.outputs.get('content', '')
elif isinstance(event, MessageAction):
content = event.content
else:
raise ValueError(f'Unknown event type: {type(event)}')
if (
'non-commercial' in content
or 'MIT' in content
or 'Apache 2.0' in content
):
return TestResult(success=True)
return TestResult(
success=False,
reason=f'The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}',
)
+39 -9
View File
@@ -10,10 +10,12 @@ import pandas as pd
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
@@ -29,7 +31,10 @@ from openhands.events.action import (
CmdRunAction,
MessageAction,
)
from openhands.events.observation import CmdOutputObservation
from openhands.events.observation import (
BrowserOutputObservation,
CmdOutputObservation,
)
from openhands.runtime.base import Runtime
from openhands.runtime.browser.browser_env import (
BROWSER_EVAL_GET_GOAL_ACTION,
@@ -37,7 +42,11 @@ from openhands.runtime.browser.browser_env import (
)
from openhands.utils.async_utils import call_async_from_sync
SUPPORTED_AGENT_CLS = {'BrowsingAgent'}
SUPPORTED_AGENT_CLS = {'BrowsingAgent', 'CodeActAgent'}
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
def get_config(
@@ -47,25 +56,32 @@ def get_config(
config = AppConfig(
default_agent=metadata.agent_class,
run_as_openhands=False,
runtime='eventstream',
runtime=os.environ.get('RUNTIME', 'eventstream'),
max_iterations=metadata.max_iterations,
sandbox=SandboxConfig(
base_container_image='xingyaoww/od-eval-miniwob:v1.0',
enable_auto_lint=True,
use_host_network=False,
browsergym_eval_env=env_id,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,
workspace_mount_path=None,
)
config.set_llm_config(metadata.llm_config)
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config, metadata.eval_output_dir, env_id
)
)
return config
def initialize_runtime(
runtime: Runtime,
) -> str:
) -> tuple[str, BrowserOutputObservation]:
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
@@ -85,8 +101,14 @@ def initialize_runtime(
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
goal = obs.content
# Run noop to get the initial browser observation (e.g., the page URL & content)
action = BrowseInteractiveAction(browser_actions='noop(1000)')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
return goal
return goal, obs
def complete_runtime(
@@ -117,7 +139,7 @@ def process_instance(
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
env_id = instance.id
env_id = instance.instance_id
config = get_config(metadata, env_id)
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
@@ -129,7 +151,12 @@ def process_instance(
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
task_str = initialize_runtime(runtime)
task_str, obs = initialize_runtime(runtime)
task_str += (
f'\nInitial browser state (output of `noop(1000)`):\n{obs.get_agent_obs_text()}'
)
state: State | None = asyncio.run(
run_controller(
config=config,
@@ -137,6 +164,9 @@ def process_instance(
content=task_str
), # take output from initialize_runtime
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
)
)
@@ -159,7 +189,7 @@ def process_instance(
return_val = complete_runtime(runtime)
logger.info(f'Return value from complete_runtime: {return_val}')
reward = max(return_val['rewards'])
reward = max(return_val['rewards'], default=0)
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
+17
View File
@@ -0,0 +1,17 @@
FROM python:3.11-bookworm
# For OpenHands agents to explore the dataset directories, please download the full benchmark [here](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EQuA6uJ3CtRHvRfZ2GiN1tYBRVJE4DSUD10MW61fr7HuSQ?e=sCBegG) and unzip it with password `scienceagentbench`.
# **Please DO NOT redistribute the unzipped data files online.**
# It will download a benchmark.zip file to the current directory.
# unzip it and put the benchmark folder under evaluation/scienceagentbench/
RUN mkdir -p /benchmark
COPY benchmark /benchmark
RUN mkdir -p /workspace
WORKDIR /workspace
# pushd evaluation/scienceagentbench
# docker build -t xingyaoww/openhands-eval-scienceagentbench .
# popd
@@ -0,0 +1,25 @@
FROM mambaorg/micromamba:debian12
USER root
# For https://github.com/OSU-NLP-Group/ScienceAgentBench/tree/main?tab=readme-ov-file#code-generation-with-agents
RUN micromamba create -n sci-agent-eval python=3.10 pip setuptools wheel
RUN micromamba run -n sci-agent-eval pip install pip-tools
RUN mkdir -p /workspace
WORKDIR /workspace
RUN apt-get update && apt-get install -y git
RUN git clone https://github.com/OSU-NLP-Group/ScienceAgentBench.git /workspace/
RUN git checkout 4eddc7db6449a5ade3e37285747c8b208cd54ce7
RUN micromamba create -n sci-agent python=3.10 pip setuptools wheel
RUN micromamba run -n sci-agent pip install -r requirements.txt
# Replace all occurence of conda with micromamba under the /workspace
RUN find ./ -type f -exec sed -i 's/conda/micromamba/g' {} \;
# pushd evaluation/scienceagentbench
# docker build -t xingyaoww/openhands-eval-scienceagentbench-evaluator -f Dockerfile.evaluator .
# popd
+54
View File
@@ -0,0 +1,54 @@
# ScienceAgentBench Evaluation with OpenHands
This folder contains the evaluation harness for [ScienceAgentBench](https://osu-nlp-group.github.io/ScienceAgentBench/) (paper: https://arxiv.org/abs/2410.05080).
## Setup Environment and LLM Configuration
Please follow instruction [here](../README.md#setup) to setup your local development environment and LLM.
## Setup ScienceAgentBench
To prevent benchmark data contamination, we only provide the annotation sheet on [Huggingface](https://huggingface.co/datasets/osunlp/ScienceAgentBench), which includes all necessary *inputs* to run an agent.
## Run Inference on ScienceAgentBench
```bash
./evaluation/scienceagentbench/scripts/run_infer.sh [model_config] [git-version] [use_knowledge] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
# Example
./evaluation/scienceagentbench/scripts/run_infer.sh llm.eval_gpt4o 0.9.3
```
where `model_config` is mandatory, and the rest are optional.
- `model_config`, e.g. `eval_gpt4_1106_preview`, is the config group name for your
LLM settings, as defined in your `config.toml`.
- `git-version`, e.g. `HEAD`, is the git commit hash of the OpenHands version you would
like to evaluate. It could also be a release tag like `0.6.2`.
- `use_knowledge`, e.g. `true`, specifies whether allowing the agent to use expert-provided knowledge as additional input or not. By default, it is set to `false`.
- `agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
to `CodeActAgent`.
- `eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances. By
default, the script evaluates the entire SWE-bench_Lite test set (300 issues). Note:
in order to use `eval_limit`, you must also set `agent`.
- `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
default, it is set to 30.
- `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
default, it is set to 1.
## Evaluate Generated Programs
### Extract Necessary Information from OpenHands Log
After the inference is completed, you may use the following command to extract necessary information from the output log for evaluation:
```bash
python post_proc.py [log_fname]
```
- `log_fname`, e.g. `evaluation/.../output.jsonl`, is the automatically saved trajectory log of an OpenHands agent.
Output will be write to e.g. `evaluation/.../output.converted.jsonl`
### Run evaluation
Please follow the steps [here](https://github.com/OSU-NLP-Group/ScienceAgentBench/tree/main?tab=readme-ov-file#evaluation-of-generated-code) to evaluate the generated programs.
+30
View File
@@ -0,0 +1,30 @@
import json
from argparse import ArgumentParser
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument(
'log_fname',
type=str,
)
args = parser.parse_args()
fname = args.log_fname
out_fname = args.log_fname.replace('.jsonl', '.converted.jsonl')
log = [json.loads(line) for line in open(fname)]
simple_log = [
json.dumps(
{
'instance_id': ex['instance_id'],
'instruction': ex['instruction'],
'test_result': ex['test_result'],
'cost': ex['metrics']['accumulated_cost'],
}
)
for ex in log
]
with open(out_fname, 'w+', encoding='utf-8') as f:
f.write('\n'.join(simple_log))
+291
View File
@@ -0,0 +1,291 @@
import asyncio
import os
from typing import Any
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
AppConfig,
SandboxConfig,
get_llm_config_arg,
get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
LOCAL_DATASET_PATH = os.path.join(os.path.dirname(__file__), 'benchmark')
def format_task_dict(example, use_knowledge):
task = {
'instance_id': example['instance_id'],
'task_inst': example['task_inst'],
'dataset_path': '/benchmark/datasets/'
+ example['dataset_folder_tree'].split('\n')[0][4:],
'dataset_folder_tree': example['dataset_folder_tree'],
'dataset_preview': example['dataset_preview'],
'pred_program_name': 'pred_' + example['gold_program_name'],
}
if use_knowledge:
task['task_inst'] += '\n' + str(example['domain_knowledge'])
return task
def get_config(
metadata: EvalMetadata,
instance_id: str,
) -> AppConfig:
config = AppConfig(
default_agent=metadata.agent_class,
run_as_openhands=False,
runtime=os.environ.get('RUNTIME', 'eventstream'),
max_budget_per_task=4,
max_iterations=metadata.max_iterations,
sandbox=SandboxConfig(
base_container_image='docker.io/xingyaoww/openhands-eval-scienceagentbench',
enable_auto_lint=True,
use_host_network=False,
timeout=300,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,
workspace_mount_path=None,
)
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config,
metadata.eval_output_dir,
instance_id,
)
)
return config
def initialize_runtime(
runtime: Runtime,
instance: pd.Series, # this argument is not required
):
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
"""
logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
obs: CmdOutputObservation
# Set up workspace directories
action = CmdRunAction(command='mkdir -p /workspace/pred_programs')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
action = CmdRunAction(command='mkdir -p /workspace/pred_results')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
dataset_name = instance['dataset_folder_tree'].split('\n')[0][4:].rstrip('/')
# Copy the dataset to the workspace
dataset_dir = os.path.join(
LOCAL_DATASET_PATH,
'datasets',
dataset_name,
)
runtime.copy_to(dataset_dir, '/workspace/benchmark/datasets', recursive=True)
# Check the dataset exists
action = CmdRunAction(
command='cd /workspace/benchmark/datasets && ls',
keep_prompt=False,
)
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
assert dataset_name in obs.content
logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
def complete_runtime(
runtime: Runtime,
instance: pd.Series,
) -> dict[str, Any]:
"""Complete the runtime for the agent.
This function is called before the runtime is used to run the agent.
If you need to do something in the sandbox to get the correctness metric after
the agent has run, modify this function.
"""
logger.info(f"{'-' * 50} BEGIN Runtime Completion Fn {'-' * 50}")
obs: CmdOutputObservation
test_result = {}
action = CmdRunAction(command='cd /workspace')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
assert obs.exit_code == 0
action = CmdRunAction(
command=f'cat pred_programs/{instance.pred_program_name}',
keep_prompt=False,
)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
if obs.exit_code == 0:
test_result = {'program': obs.content}
else:
test_result = {'program': 'ERROR'}
logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
return test_result
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
instance_id = instance.instance_id.replace('/', '__')
config = get_config(metadata, instance_id)
# Set up the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance_id}.')
instruction = f"""You are an expert Python programming assistant that helps scientist users to write high-quality code to solve their tasks.
Given a user request, you are expected to write a complete program that accomplishes the requested task and save any outputs to `/workspace/pred_results/` in the correct format.
Here's the user request you need to work on:
{instance.task_inst}
You can access the dataset at `{instance.dataset_path}`. Here is the directory structure of the dataset:
```
{instance.dataset_folder_tree}
```
Here are some helpful previews for the dataset file(s):
{instance.dataset_preview}
Please save your program as `/workspace/pred_programs/{instance.pred_program_name}`.
Then, please run the program to check and fix any errors.
Please do NOT run the program in the background.
If the program uses some packages that are incompatible, please figure out alternative implementations and do NOT restart the environment.
"""
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
initialize_runtime(runtime, instance)
# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
),
)
)
# ======= Attempt to evaluate the agent's edits =======
test_result = complete_runtime(runtime, instance)
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
if state is None:
raise ValueError('State should not be None.')
metrics = state.metrics.get() if state.metrics else None
# history is now available as a stream of events, rather than list of pairs of (Action, Observation)
# for compatibility with the existing output format, we can remake the pairs here
# remove when it becomes unnecessary
histories = state.history.compatibility_for_eval_history_pairs()
# Save the output
output = EvalOutput(
instance_id=instance.instance_id,
instruction=instruction,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
test_result=test_result,
)
return output
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
'--use_knowledge',
type=str,
default='false',
choices=['true', 'false'],
help='use expert-provided knowledge or not',
)
args, _ = parser.parse_known_args()
sab_dataset = load_dataset('osunlp/ScienceAgentBench', split='validation')
dataset_processed = []
for example in tqdm(sab_dataset):
dataset_processed.append(
format_task_dict(example, args.use_knowledge == 'true')
)
dataset = pd.DataFrame(dataset_processed)
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
metadata = make_metadata(
llm_config,
'ScienceAgentBench',
args.agent_cls,
args.max_iterations,
args.eval_note,
args.eval_output_dir,
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
dataset['instance_id'] = dataset['instance_id'].apply(str)
instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
run_evaluation(
instances, metadata, output_file, args.eval_num_workers, process_instance
)
+49
View File
@@ -0,0 +1,49 @@
#!/bin/bash
set -eo pipefail
source "evaluation/utils/version_control.sh"
MODEL_CONFIG=$1
COMMIT_HASH=$2
USE_KNOWLEDGE=$3
AGENT=$4
EVAL_LIMIT=$5
NUM_WORKERS=$6
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
echo "Number of workers not specified, use default $NUM_WORKERS"
fi
checkout_eval_branch
if [ -z "$AGENT" ]; then
echo "Agent not specified, use default CodeActAgent"
AGENT="CodeActAgent"
fi
if [ -z "$USE_KNOWLEDGE" ]; then
echo "Use knowledge not specified, use default False"
USE_KNOWLEDGE=false
fi
get_agent_version
echo "AGENT: $AGENT"
echo "AGENT_VERSION: $AGENT_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
COMMAND="poetry run python evaluation/scienceagentbench/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--use_knowledge $USE_KNOWLEDGE \
--max-iterations 30 \
--eval-num-workers $NUM_WORKERS \
--eval-note $AGENT_VERSION" \
if [ -n "$EVAL_LIMIT" ]; then
echo "EVAL_LIMIT: $EVAL_LIMIT"
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
# Run the command
eval $COMMAND
+1 -1
View File
@@ -239,7 +239,7 @@ def process_instance(
# Create a directory structure that matches the expected format
# NOTE: this is a hack to make the eval report format consistent
# with the original SWE-Bench eval script
log_dir = os.path.join(temp_dir, 'logs', instance_id)
log_dir = os.path.join(temp_dir, 'logs', instance_id.lower())
os.makedirs(log_dir, exist_ok=True)
test_output_path = os.path.join(log_dir, 'test_output.txt')
with open(test_output_path, 'w') as f:
+15 -10
View File
@@ -20,6 +20,7 @@ from evaluation.utils.shared import (
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
@@ -40,6 +41,7 @@ from openhands.utils.async_utils import call_async_from_sync
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
@@ -88,6 +90,13 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
'5. Think about edgecases and make sure your fix handles them as well\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
)
if RUN_WITH_BROWSING:
instruction += (
'<IMPORTANT!>\n'
'You SHOULD NEVER attempt to browse the web. '
'</IMPORTANT!>\n'
)
return instruction
@@ -101,7 +110,7 @@ def get_instance_docker_image(instance_id: str) -> str:
image_name = image_name.replace(
'__', '_s_'
) # to comply with docker image naming convention
return DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name
return (DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name).lower()
def get_config(
@@ -142,18 +151,14 @@ def get_config(
workspace_base=None,
workspace_mount_path=None,
)
if metadata.llm_config.log_completions:
metadata.llm_config.log_completions_folder = os.path.join(
metadata.eval_output_dir, 'llm_completions', instance['instance_id']
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
)
logger.info(
f'Logging LLM completions for instance {instance["instance_id"]} to '
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
)
agent_config = AgentConfig(
codeact_enable_jupyter=False,
codeact_enable_browsing_delegate=False,
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
)
config.set_agent_config(agent_config)
+11
View File
@@ -34,6 +34,11 @@ if [ -z "$USE_INSTANCE_IMAGE" ]; then
USE_INSTANCE_IMAGE=true
fi
if [ -z "$RUN_WITH_BROWSING" ]; then
echo "RUN_WITH_BROWSING not specified, use default false"
RUN_WITH_BROWSING=false
fi
if [ -z "$DATASET" ]; then
echo "DATASET not specified, use default princeton-nlp/SWE-bench_Lite"
@@ -47,6 +52,8 @@ fi
export USE_INSTANCE_IMAGE=$USE_INSTANCE_IMAGE
echo "USE_INSTANCE_IMAGE: $USE_INSTANCE_IMAGE"
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"
get_agent_version
@@ -67,6 +74,10 @@ if [ "$USE_HINT_TEXT" = false ]; then
EVAL_NOTE="$EVAL_NOTE-no-hint"
fi
if [ "$RUN_WITH_BROWSING" = true ]; then
EVAL_NOTE="$EVAL_NOTE-with-browsing"
fi
if [ -n "$EXP_NAME" ]; then
EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
fi
+17
View File
@@ -411,3 +411,20 @@ def reset_logger_for_multiprocessing(
)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)
def update_llm_config_for_completions_logging(
llm_config: LLMConfig,
eval_output_dir: str,
instance_id: str,
) -> LLMConfig:
"""Update the LLM config for logging completions."""
if llm_config.log_completions:
llm_config.log_completions_folder = os.path.join(
eval_output_dir, 'llm_completions', instance_id
)
logger.info(
f'Logging LLM completions for instance {instance_id} to '
f'{llm_config.log_completions_folder}'
)
return llm_config
@@ -5,7 +5,6 @@ import { FeedbackForm } from "#/components/feedback-form";
describe("FeedbackForm", () => {
const user = userEvent.setup();
const onSubmitMock = vi.fn();
const onCloseMock = vi.fn();
afterEach(() => {
@@ -13,7 +12,7 @@ describe("FeedbackForm", () => {
});
it("should render correctly", () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
screen.getByLabelText("Email");
screen.getByLabelText("Private");
@@ -24,7 +23,7 @@ describe("FeedbackForm", () => {
});
it("should switch between private and public permissions", async () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
const privateRadio = screen.getByLabelText("Private");
const publicRadio = screen.getByLabelText("Public");
@@ -40,69 +39,11 @@ describe("FeedbackForm", () => {
expect(publicRadio).not.toBeChecked();
});
it("should call onSubmit when the form is submitted", async () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
const email = screen.getByLabelText("Email");
await user.type(email, "test@test.test");
await user.click(screen.getByRole("button", { name: "Submit" }));
expect(onSubmitMock).toHaveBeenCalledWith("private", "test@test.test"); // private is the default value
});
it("should not call onSubmit when the email is invalid", async () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
const email = screen.getByLabelText("Email");
const submitButton = screen.getByRole("button", { name: "Submit" });
await user.click(submitButton);
expect(onSubmitMock).not.toHaveBeenCalled();
await user.type(email, "test");
await user.click(submitButton);
expect(onSubmitMock).not.toHaveBeenCalled();
});
it("should submit public permissions when the public radio is checked", async () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
const email = screen.getByLabelText("Email");
const publicRadio = screen.getByLabelText("Public");
await user.type(email, "test@test.test");
await user.click(publicRadio);
await user.click(screen.getByRole("button", { name: "Submit" }));
expect(onSubmitMock).toHaveBeenCalledWith("public", "test@test.test");
});
it("should call onClose when the close button is clicked", async () => {
render(<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />);
render(<FeedbackForm polarity="positive" onClose={onCloseMock} />);
await user.click(screen.getByRole("button", { name: "Cancel" }));
expect(onSubmitMock).not.toHaveBeenCalled();
expect(onCloseMock).toHaveBeenCalled();
});
it("should disable the buttons if isSubmitting is true", () => {
const { rerender } = render(
<FeedbackForm onSubmit={onSubmitMock} onClose={onCloseMock} />,
);
const submitButton = screen.getByRole("button", { name: "Submit" });
const cancelButton = screen.getByRole("button", { name: "Cancel" });
expect(submitButton).not.toBeDisabled();
expect(cancelButton).not.toBeDisabled();
rerender(
<FeedbackForm
onSubmit={onSubmitMock}
onClose={onCloseMock}
isSubmitting
/>,
);
expect(submitButton).toBeDisabled();
expect(cancelButton).toBeDisabled();
});
});
@@ -16,13 +16,16 @@ vi.mock("../../services/fileService", async () => ({
}));
const renderFileExplorerWithRunningAgentState = () =>
renderWithProviders(<FileExplorer error={null} />, {
preloadedState: {
agent: {
curAgentState: AgentState.RUNNING,
renderWithProviders(
<FileExplorer error={null} isOpen onToggle={() => {}} />,
{
preloadedState: {
agent: {
curAgentState: AgentState.RUNNING,
},
},
},
});
);
describe.skip("FileExplorer", () => {
afterEach(() => {
+43 -2
View File
@@ -1,12 +1,12 @@
{
"name": "openhands-frontend",
"version": "0.11.0",
"version": "0.12.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "openhands-frontend",
"version": "0.11.0",
"version": "0.12.0",
"dependencies": {
"@monaco-editor/react": "^4.6.0",
"@nextui-org/react": "^2.4.8",
@@ -26,6 +26,7 @@
"isbot": "^5.1.17",
"jose": "^5.9.4",
"monaco-editor": "^0.52.0",
"posthog-js": "^1.176.0",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-highlight": "^0.15.0",
@@ -7864,6 +7865,16 @@
"node": ">=6.6.0"
}
},
"node_modules/core-js": {
"version": "3.38.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.38.1.tgz",
"integrity": "sha512-OP35aUorbU3Zvlx7pjsFdu1rGNnD4pgw/CWoYzRY3t2EzoVT7shKHY1dlAy3f41cGIO7ZDPQimhGFTlEYkG/Hw==",
"hasInstallScript": true,
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/core-js"
}
},
"node_modules/core-util-is": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz",
@@ -9666,6 +9677,11 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/fflate": {
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.4.8.tgz",
"integrity": "sha512-FJqqoDBR00Mdj9ppamLa/Y7vxm+PRmNWA67N846RvsoYVMKB4q3y/de5PA7gUmRMYK/8CMz2GDZQmCRN1wBcWA=="
},
"node_modules/file-entry-cache": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz",
@@ -19653,6 +19669,31 @@
"resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz",
"integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ=="
},
"node_modules/posthog-js": {
"version": "1.176.0",
"resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.176.0.tgz",
"integrity": "sha512-T5XKNtRzp7q6CGb7Vc7wAI76rWap9fiuDUPxPsyPBPDkreKya91x9RIsSapAVFafwD1AEin1QMczCmt9Le9BWw==",
"dependencies": {
"core-js": "^3.38.1",
"fflate": "^0.4.8",
"preact": "^10.19.3",
"web-vitals": "^4.2.0"
}
},
"node_modules/posthog-js/node_modules/web-vitals": {
"version": "4.2.4",
"resolved": "https://registry.npmjs.org/web-vitals/-/web-vitals-4.2.4.tgz",
"integrity": "sha512-r4DIlprAGwJ7YM11VZp4R884m0Vmgr6EAKe3P+kO0PPj3Unqyvv59rczf6UiGcb9Z8QxZVcqKNwv/g0WNdWwsw=="
},
"node_modules/preact": {
"version": "10.24.3",
"resolved": "https://registry.npmjs.org/preact/-/preact-10.24.3.tgz",
"integrity": "sha512-Z2dPnBnMUfyQfSQ+GBdsGa16hz35YmLmtTLhM169uW944hYL6xzTYkJjC07j+Wosz733pMWx0fgON3JNw1jJQA==",
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/preact"
}
},
"node_modules/prelude-ls": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
+2 -1
View File
@@ -1,6 +1,6 @@
{
"name": "openhands-frontend",
"version": "0.11.0",
"version": "0.12.0",
"private": true,
"type": "module",
"engines": {
@@ -25,6 +25,7 @@
"isbot": "^5.1.17",
"jose": "^5.9.4",
"monaco-editor": "^0.52.0",
"posthog-js": "^1.176.0",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-highlight": "^0.15.0",
+28 -107
View File
@@ -1,4 +1,4 @@
import { getValidFallbackHost } from "#/utils/get-valid-fallback-host";
import { request } from "#/services/api";
import {
SaveFileSuccessResponse,
FileUploadSuccessResponse,
@@ -9,36 +9,13 @@ import {
GetConfigResponse,
} from "./open-hands.types";
/**
* Generate the base URL of the OpenHands API
* @returns Base URL of the OpenHands API
*/
const generateBaseURL = () => {
const fallback = getValidFallbackHost();
const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback;
if (typeof window === "undefined") {
return `http://${baseUrl}`;
}
return `${window.location.protocol}//${baseUrl}`;
};
/**
* Class to interact with the OpenHands API
*/
class OpenHands {
/**
* Base URL of the OpenHands API
*/
static BASE_URL = generateBaseURL();
/**
* Retrieve the list of models available
* @returns List of models available
*/
static async getModels(): Promise<string[]> {
const response = await fetch(`${OpenHands.BASE_URL}/api/options/models`);
return response.json();
return request("/api/options/models");
}
/**
@@ -46,8 +23,7 @@ class OpenHands {
* @returns List of agents available
*/
static async getAgents(): Promise<string[]> {
const response = await fetch(`${OpenHands.BASE_URL}/api/options/agents`);
return response.json();
return request(`/api/options/agents`);
}
/**
@@ -55,178 +31,123 @@ class OpenHands {
* @returns List of security analyzers available
*/
static async getSecurityAnalyzers(): Promise<string[]> {
const response = await fetch(
`${OpenHands.BASE_URL}/api/options/security-analyzers`,
);
return response.json();
return request(`/api/options/security-analyzers`);
}
static async getConfig(): Promise<GetConfigResponse> {
const response = await fetch("config.json", {
headers: {
"Cache-Control": "no-cache",
},
});
return response.json();
return request("/config.json");
}
/**
* Retrieve the list of files available in the workspace
* @param token User token provided by the server
* @param path Path to list files from
* @returns List of files available in the given path. If path is not provided, it lists all the files in the workspace
*/
static async getFiles(token: string, path?: string): Promise<string[]> {
const url = new URL(`${OpenHands.BASE_URL}/api/list-files`);
if (path) url.searchParams.append("path", path);
const response = await fetch(url.toString(), {
headers: {
Authorization: `Bearer ${token}`,
},
});
return response.json();
static async getFiles(path?: string): Promise<string[]> {
let url = "/api/list-files";
if (path) url += `?path=${encodeURIComponent(path)}`;
return request(url);
}
/**
* Retrieve the content of a file
* @param token User token provided by the server
* @param path Full path of the file to retrieve
* @returns Content of the file
*/
static async getFile(token: string, path: string): Promise<string> {
const url = new URL(`${OpenHands.BASE_URL}/api/select-file`);
url.searchParams.append("file", path);
const response = await fetch(url.toString(), {
headers: {
Authorization: `Bearer ${token}`,
},
});
const data = await response.json();
static async getFile(path: string): Promise<string> {
const url = `/api/select-file?file=${encodeURIComponent(path)}`;
const data = await request(url);
return data.code;
}
/**
* Save the content of a file
* @param token User token provided by the server
* @param path Full path of the file to save
* @param content Content to save in the file
* @returns Success message or error message
*/
static async saveFile(
token: string,
path: string,
content: string,
): Promise<SaveFileSuccessResponse | ErrorResponse> {
const response = await fetch(`${OpenHands.BASE_URL}/api/save-file`, {
return request(`/api/save-file`, {
method: "POST",
body: JSON.stringify({ filePath: path, content }),
headers: {
Authorization: `Bearer ${token}`,
"Content-Type": "application/json",
},
});
return response.json();
}
/**
* Upload a file to the workspace
* @param token User token provided by the server
* @param file File to upload
* @returns Success message or error message
*/
static async uploadFiles(
token: string,
file: File[],
): Promise<FileUploadSuccessResponse | ErrorResponse> {
const formData = new FormData();
file.forEach((f) => formData.append("files", f));
const response = await fetch(`${OpenHands.BASE_URL}/api/upload-files`, {
return request(`/api/upload-files`, {
method: "POST",
body: formData,
headers: {
Authorization: `Bearer ${token}`,
},
});
return response.json();
}
/**
* Get the blob of the workspace zip
* @param token User token provided by the server
* @returns Blob of the workspace zip
*/
static async getWorkspaceZip(token: string): Promise<Blob> {
const response = await fetch(`${OpenHands.BASE_URL}/api/zip-directory`, {
headers: {
Authorization: `Bearer ${token}`,
},
});
static async getWorkspaceZip(): Promise<Blob> {
const response = await request(`/api/zip-directory`, {}, false, true);
return response.blob();
}
/**
* Send feedback to the server
* @param token User token provided by the server
* @param data Feedback data
* @returns The stored feedback data
*/
static async sendFeedback(
token: string,
data: Feedback,
): Promise<FeedbackResponse> {
const response = await fetch(`${OpenHands.BASE_URL}/api/submit-feedback`, {
static async submitFeedback(data: Feedback): Promise<FeedbackResponse> {
return request(`/api/submit-feedback`, {
method: "POST",
body: JSON.stringify(data),
headers: {
Authorization: `Bearer ${token}`,
"Content-Type": "application/json",
},
});
return response.json();
}
/**
* Get the GitHub access token
* @param code Code provided by GitHub
* @returns GitHub access token
*/
static async getGitHubAccessToken(
code: string,
): Promise<GitHubAccessTokenResponse> {
const response = await fetch(`${OpenHands.BASE_URL}/api/github/callback`, {
return request(`/api/github/callback`, {
method: "POST",
body: JSON.stringify({ code }),
headers: {
"Content-Type": "application/json",
},
});
return response.json();
}
/**
* Check if the user is authenticated
* @param login The user's GitHub login handle
* @returns Whether the user is authenticated
* Authenticate with GitHub token
* @returns Response with authentication status and user info if successful
*/
static async isAuthenticated(login: string): Promise<boolean> {
const response = await fetch(`${OpenHands.BASE_URL}/api/authenticate`, {
method: "POST",
body: JSON.stringify({ login }),
headers: {
"Content-Type": "application/json",
static async authenticate(): Promise<Response> {
return request(
`/api/authenticate`,
{
method: "POST",
},
});
return response.status === 200;
true,
);
}
}
+6 -1
View File
@@ -27,11 +27,16 @@ export interface GitHubAccessTokenResponse {
access_token: string;
}
export interface AuthenticationResponse {
message: string;
login?: string; // Only present when allow list is enabled
}
export interface Feedback {
version: string;
email: string;
token: string;
feedback: "positive" | "negative";
polarity: "positive" | "negative";
permissions: "public" | "private";
trajectory: unknown[];
}
@@ -0,0 +1,42 @@
import { useFetcher } from "@remix-run/react";
import { ModalBackdrop } from "./modals/modal-backdrop";
import ModalBody from "./modals/ModalBody";
import ModalButton from "./buttons/ModalButton";
import {
BaseModalTitle,
BaseModalDescription,
} from "./modals/confirmation-modals/BaseModal";
export function AnalyticsConsentFormModal() {
const fetcher = useFetcher({ key: "set-consent" });
return (
<ModalBackdrop>
<fetcher.Form
method="POST"
action="/set-consent"
className="flex flex-col gap-2"
>
<ModalBody>
<BaseModalTitle title="Your Privacy Preferences" />
<BaseModalDescription>
We use tools to understand how our application is used to improve
your experience. You can enable or disable analytics. Your
preferences will be stored and can be updated anytime.
</BaseModalDescription>
<label className="flex gap-2 items-center self-start">
<input name="analytics" type="checkbox" defaultChecked />
Send anonymous usage data
</label>
<ModalButton
type="submit"
text="Confirm Preferences"
className="bg-primary text-white w-full hover:opacity-80"
/>
</ModalBody>
</fetcher.Form>
</ModalBackdrop>
);
}
+10 -46
View File
@@ -1,6 +1,5 @@
import { useDispatch, useSelector } from "react-redux";
import React from "react";
import { useFetcher } from "@remix-run/react";
import { useSocket } from "#/context/socket";
import { convertImageToBase64 } from "#/utils/convert-image-to-base-64";
import { ChatMessage } from "./chat-message";
@@ -13,10 +12,6 @@ import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { generateAgentStateChangeEvent } from "#/services/agentStateService";
import { FeedbackModal } from "./feedback-modal";
import { Feedback } from "#/api/open-hands.types";
import { getToken } from "#/services/auth";
import { removeApiKey, removeUnwantedKeys } from "#/utils/utils";
import { clientAction } from "#/routes/submit-feedback";
import { useScrollToBottom } from "#/hooks/useScrollToBottom";
import TypingIndicator from "./chat/TypingIndicator";
import ConfirmationButtons from "./chat/ConfirmationButtons";
@@ -24,16 +19,13 @@ import { ErrorMessage } from "./error-message";
import { ContinueButton } from "./continue-button";
import { ScrollToBottomButton } from "./scroll-to-bottom-button";
const FEEDBACK_VERSION = "1.0";
const isErrorMessage = (
message: Message | ErrorMessage,
): message is ErrorMessage => "error" in message;
export function ChatInterface() {
const { send, events } = useSocket();
const { send } = useSocket();
const dispatch = useDispatch();
const fetcher = useFetcher<typeof clientAction>({ key: "feedback" });
const scrollRef = React.useRef<HTMLDivElement>(null);
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
useScrollToBottom(scrollRef);
@@ -44,7 +36,6 @@ export function ChatInterface() {
const [feedbackPolarity, setFeedbackPolarity] = React.useState<
"positive" | "negative"
>("positive");
const [feedbackShared, setFeedbackShared] = React.useState(0);
const [feedbackModalIsOpen, setFeedbackModalIsOpen] = React.useState(false);
const handleSendMessage = async (content: string, files: File[]) => {
@@ -71,30 +62,6 @@ export function ChatInterface() {
setFeedbackPolarity(polarity);
};
const handleSubmitFeedback = (
permissions: "private" | "public",
email: string,
) => {
const feedback: Feedback = {
version: FEEDBACK_VERSION,
feedback: feedbackPolarity,
email,
permissions,
token: getToken(),
trajectory: removeApiKey(removeUnwantedKeys(events)),
};
const formData = new FormData();
formData.append("feedback", JSON.stringify(feedback));
fetcher.submit(formData, {
action: "/submit-feedback",
method: "POST",
});
setFeedbackShared(messages.length);
};
return (
<div className="h-full flex flex-col justify-between">
<div
@@ -130,16 +97,14 @@ export function ChatInterface() {
<div className="flex flex-col gap-[6px] px-4 pb-4">
<div className="flex justify-between relative">
{feedbackShared !== messages.length && messages.length > 3 && (
<FeedbackActions
onPositiveFeedback={() =>
onClickShareFeedbackActionButton("positive")
}
onNegativeFeedback={() =>
onClickShareFeedbackActionButton("negative")
}
/>
)}
<FeedbackActions
onPositiveFeedback={() =>
onClickShareFeedbackActionButton("positive")
}
onNegativeFeedback={() =>
onClickShareFeedbackActionButton("negative")
}
/>
<div className="absolute left-1/2 transform -translate-x-1/2 bottom-0">
{messages.length > 2 &&
curAgentState === AgentState.AWAITING_USER_INPUT && (
@@ -163,9 +128,8 @@ export function ChatInterface() {
<FeedbackModal
isOpen={feedbackModalIsOpen}
isSubmitting={fetcher.state === "submitting"}
onClose={() => setFeedbackModalIsOpen(false)}
onSubmit={handleSubmitFeedback}
polarity={feedbackPolarity}
/>
</div>
);
+68 -14
View File
@@ -1,27 +1,81 @@
import React from "react";
import hotToast from "react-hot-toast";
import ModalButton from "./buttons/ModalButton";
import { Feedback } from "#/api/open-hands.types";
import OpenHands from "#/api/open-hands";
const FEEDBACK_VERSION = "1.0";
const VIEWER_PAGE = "https://www.all-hands.dev/share";
interface FeedbackFormProps {
onSubmit: (permissions: "private" | "public", email: string) => void;
onClose: () => void;
isSubmitting?: boolean;
polarity: "positive" | "negative";
}
export function FeedbackForm({
onSubmit,
onClose,
isSubmitting,
}: FeedbackFormProps) {
const handleSubmit = (event: React.FormEvent<HTMLFormElement>) => {
export function FeedbackForm({ onClose, polarity }: FeedbackFormProps) {
const [isSubmitting, setIsSubmitting] = React.useState(false);
const copiedToClipboardToast = () => {
hotToast("Password copied to clipboard", {
icon: "📋",
position: "bottom-right",
});
};
const onPressToast = (password: string) => {
navigator.clipboard.writeText(password);
copiedToClipboardToast();
};
const shareFeedbackToast = (
message: string,
link: string,
password: string,
) => {
hotToast(
<div className="flex flex-col gap-1">
<span>{message}</span>
<a
data-testid="toast-share-url"
className="text-blue-500 underline"
onClick={() => onPressToast(password)}
href={link}
target="_blank"
rel="noreferrer"
>
Go to shared feedback
</a>
<span onClick={() => onPressToast(password)} className="cursor-pointer">
Password: {password} <span className="text-gray-500">(copy)</span>
</span>
</div>,
{ duration: 10000 },
);
};
const handleSubmit = async (event: React.FormEvent<HTMLFormElement>) => {
event?.preventDefault();
const formData = new FormData(event.currentTarget);
setIsSubmitting(true);
const email = formData.get("email")?.toString();
const permissions = formData.get("permissions")?.toString() as
| "private"
| "public"
| undefined;
const email = formData.get("email")?.toString() || "";
const permissions = (formData.get("permissions")?.toString() ||
"private") as "private" | "public";
if (email) onSubmit(permissions || "private", email);
const feedback: Feedback = {
version: FEEDBACK_VERSION,
email,
polarity,
permissions,
trajectory: [],
token: "",
};
const response = await OpenHands.submitFeedback(feedback);
const { message, feedback_id, password } = response.body; // eslint-disable-line
const link = `${VIEWER_PAGE}?share_id=${feedback_id}`;
shareFeedbackToast(message, link, password);
setIsSubmitting(false);
};
return (
+3 -73
View File
@@ -1,6 +1,4 @@
import React from "react";
import hotToast, { toast } from "react-hot-toast";
import { useFetcher } from "@remix-run/react";
import { FeedbackForm } from "./feedback-form";
import {
BaseModalTitle,
@@ -8,82 +6,18 @@ import {
} from "./modals/confirmation-modals/BaseModal";
import { ModalBackdrop } from "./modals/modal-backdrop";
import ModalBody from "./modals/ModalBody";
import { clientAction } from "#/routes/submit-feedback";
interface FeedbackModalProps {
onSubmit: (permissions: "private" | "public", email: string) => void;
onClose: () => void;
isOpen: boolean;
isSubmitting?: boolean;
polarity: "positive" | "negative";
}
export function FeedbackModal({
onSubmit,
onClose,
isOpen,
isSubmitting,
polarity,
}: FeedbackModalProps) {
const fetcher = useFetcher<typeof clientAction>({ key: "feedback" });
const isInitialRender = React.useRef(true);
const copiedToClipboardToast = () => {
hotToast("Password copied to clipboard", {
icon: "📋",
position: "bottom-right",
});
};
const onPressToast = (password: string) => {
navigator.clipboard.writeText(password);
copiedToClipboardToast();
};
const shareFeedbackToast = (
message: string,
link: string,
password: string,
) => {
hotToast(
<div className="flex flex-col gap-1">
<span>{message}</span>
<a
data-testid="toast-share-url"
className="text-blue-500 underline"
onClick={() => onPressToast(password)}
href={link}
target="_blank"
rel="noreferrer"
>
Go to shared feedback
</a>
<span onClick={() => onPressToast(password)} className="cursor-pointer">
Password: {password} <span className="text-gray-500">(copy)</span>
</span>
</div>,
{ duration: 10000 },
);
};
React.useEffect(() => {
if (isInitialRender.current) {
isInitialRender.current = false;
return;
}
// Handle feedback submission
if (fetcher.state === "idle" && fetcher.data) {
if (!fetcher.data.success) {
toast.error("Error submitting feedback");
} else if (fetcher.data.data) {
const { data } = fetcher.data;
const { message, link, password } = data;
shareFeedbackToast(message, link, password);
}
onClose();
}
}, [fetcher.state, fetcher.data?.success]);
if (!isOpen) return null;
return (
@@ -91,11 +25,7 @@ export function FeedbackModal({
<ModalBody>
<BaseModalTitle title="Feedback" />
<BaseModalDescription description="To help us improve, we collect feedback from your interactions to improve our prompts. By submitting this form, you consent to us collecting this data." />
<FeedbackForm
onSubmit={onSubmit}
onClose={onClose}
isSubmitting={isSubmitting}
/>
<FeedbackForm onClose={onClose} polarity={polarity} />
</ModalBody>
</ModalBackdrop>
);
@@ -91,14 +91,15 @@ function ExplorerActions({
}
interface FileExplorerProps {
isOpen: boolean;
onToggle: () => void;
error: string | null;
}
function FileExplorer({ error }: FileExplorerProps) {
function FileExplorer({ error, isOpen, onToggle }: FileExplorerProps) {
const { revalidate } = useRevalidator();
const { paths, setPaths } = useFiles();
const [isHidden, setIsHidden] = React.useState(false);
const [isDragging, setIsDragging] = React.useState(false);
const { curAgentState } = useSelector((state: RootState) => state.agent);
@@ -117,52 +118,47 @@ function FileExplorer({ error }: FileExplorerProps) {
return;
}
dispatch(setRefreshID(Math.random()));
// TODO: Get token from data loader
const token = localStorage.getItem("token");
if (token) OpenHands.getFiles(token).then(setPaths);
OpenHands.getFiles().then(setPaths);
revalidate();
};
const uploadFileData = async (files: FileList) => {
try {
const token = localStorage.getItem("token");
if (token) {
const result = await OpenHands.uploadFiles(token, Array.from(files));
const result = await OpenHands.uploadFiles(Array.from(files));
if (isOpenHandsErrorResponse(result)) {
// Handle error response
toast.error(
`upload-error-${new Date().getTime()}`,
result.error || t(I18nKey.EXPLORER$UPLOAD_ERROR_MESSAGE),
);
return;
}
const uploadedCount = result.uploaded_files.length;
const skippedCount = result.skipped_files.length;
if (uploadedCount > 0) {
toast.success(
`upload-success-${new Date().getTime()}`,
t(I18nKey.EXPLORER$UPLOAD_SUCCESS_MESSAGE, {
count: uploadedCount,
}),
);
}
if (skippedCount > 0) {
const message = t(I18nKey.EXPLORER$UPLOAD_PARTIAL_SUCCESS_MESSAGE, {
count: skippedCount,
});
toast.info(message);
}
if (uploadedCount === 0 && skippedCount === 0) {
toast.info(t(I18nKey.EXPLORER$NO_FILES_UPLOADED_MESSAGE));
}
refreshWorkspace();
if (isOpenHandsErrorResponse(result)) {
// Handle error response
toast.error(
`upload-error-${new Date().getTime()}`,
result.error || t(I18nKey.EXPLORER$UPLOAD_ERROR_MESSAGE),
);
return;
}
const uploadedCount = result.uploaded_files.length;
const skippedCount = result.skipped_files.length;
if (uploadedCount > 0) {
toast.success(
`upload-success-${new Date().getTime()}`,
t(I18nKey.EXPLORER$UPLOAD_SUCCESS_MESSAGE, {
count: uploadedCount,
}),
);
}
if (skippedCount > 0) {
const message = t(I18nKey.EXPLORER$UPLOAD_PARTIAL_SUCCESS_MESSAGE, {
count: skippedCount,
});
toast.info(message);
}
if (uploadedCount === 0 && skippedCount === 0) {
toast.info(t(I18nKey.EXPLORER$NO_FILES_UPLOADED_MESSAGE));
}
refreshWorkspace();
} catch (e) {
// Handle unexpected errors (network issues, etc.)
toast.error(
@@ -211,7 +207,7 @@ function FileExplorer({ error }: FileExplorerProps) {
<div
className={twMerge(
"bg-neutral-800 h-full border-r-1 border-r-neutral-600 flex flex-col",
isHidden ? "w-12" : "w-60",
!isOpen ? "w-12" : "w-60",
)}
>
<div className="flex flex-col relative h-full px-3 py-2">
@@ -219,17 +215,17 @@ function FileExplorer({ error }: FileExplorerProps) {
<div
className={twMerge(
"flex items-center",
isHidden ? "justify-center" : "justify-between",
!isOpen ? "justify-center" : "justify-between",
)}
>
{!isHidden && (
{isOpen && (
<div className="text-neutral-300 font-bold text-sm">
{t(I18nKey.EXPLORER$LABEL_WORKSPACE)}
</div>
)}
<ExplorerActions
isHidden={isHidden}
toggleHidden={() => setIsHidden((prev) => !prev)}
isHidden={!isOpen}
toggleHidden={onToggle}
onRefresh={refreshWorkspace}
onUpload={selectFileInput}
/>
@@ -237,7 +233,7 @@ function FileExplorer({ error }: FileExplorerProps) {
</div>
{!error && (
<div className="overflow-auto flex-grow">
<div style={{ display: isHidden ? "none" : "block" }}>
<div style={{ display: !isOpen ? "none" : "block" }}>
<ExplorerTree files={paths} />
</div>
</div>
@@ -59,14 +59,11 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
return;
}
const token = localStorage.getItem("token");
if (token) {
try {
const newChildren = await OpenHands.getFiles(token, path);
setChildren(newChildren);
} catch (error) {
toast.error("Failed to fetch files");
}
try {
const newChildren = await OpenHands.getFiles(path);
setChildren(newChildren);
} catch (error) {
toast.error("Failed to fetch files");
}
};
@@ -77,15 +74,13 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
}, [refreshID, isOpen]);
const handleClick = async () => {
const token = localStorage.getItem("token");
if (isDirectory) {
setIsOpen((prev) => !prev);
} else if (token) {
} else {
const code = modifiedFiles[path] || files[path];
try {
const fetchedCode = await OpenHands.getFile(token, path);
const fetchedCode = await OpenHands.getFile(path);
setSelectedPath(path);
if (!code || fetchedCode !== files[path]) {
setFileContent(path, fetchedCode);
@@ -0,0 +1,94 @@
import React from "react";
import {
isGitHubErrorReponse,
retrieveAllGitHubUserRepositories,
} from "#/api/github";
import { SuggestionBox } from "#/routes/_oh._index/suggestion-box";
import { ConnectToGitHubModal } from "./modals/connect-to-github-modal";
import { ModalBackdrop } from "./modals/modal-backdrop";
import { GitHubRepositorySelector } from "#/routes/_oh._index/github-repo-selector";
import ModalButton from "./buttons/ModalButton";
import GitHubLogo from "#/assets/branding/github-logo.svg?react";
interface GitHubAuthProps {
onConnectToGitHub: () => void;
repositories: GitHubRepository[];
isLoggedIn: boolean;
}
function GitHubAuth({
onConnectToGitHub,
repositories,
isLoggedIn,
}: GitHubAuthProps) {
if (isLoggedIn) {
return <GitHubRepositorySelector repositories={repositories} />;
}
return (
<ModalButton
text="Connect to GitHub"
icon={<GitHubLogo width={20} height={20} />}
className="bg-[#791B80] w-full"
onClick={onConnectToGitHub}
/>
);
}
interface GitHubRepositoriesSuggestionBoxProps {
repositories: Awaited<
ReturnType<typeof retrieveAllGitHubUserRepositories>
> | null;
gitHubAuthUrl: string | null;
user: GitHubErrorReponse | GitHubUser | null;
}
export function GitHubRepositoriesSuggestionBox({
repositories,
gitHubAuthUrl,
user,
}: GitHubRepositoriesSuggestionBoxProps) {
const [connectToGitHubModalOpen, setConnectToGitHubModalOpen] =
React.useState(false);
const handleConnectToGitHub = () => {
if (gitHubAuthUrl) {
window.location.href = gitHubAuthUrl;
} else {
setConnectToGitHubModalOpen(true);
}
};
if (isGitHubErrorReponse(repositories)) {
return (
<SuggestionBox
title="Error Fetching Repositories"
content={
<p className="text-danger text-center">{repositories.message}</p>
}
/>
);
}
return (
<>
<SuggestionBox
title="Open a Repo"
content={
<GitHubAuth
isLoggedIn={!!user && !isGitHubErrorReponse(user)}
repositories={repositories || []}
onConnectToGitHub={handleConnectToGitHub}
/>
}
/>
{connectToGitHubModalOpen && (
<ModalBackdrop onClose={() => setConnectToGitHubModalOpen(false)}>
<ConnectToGitHubModal
onClose={() => setConnectToGitHubModalOpen(false)}
/>
</ModalBackdrop>
)}
</>
);
}
@@ -14,12 +14,14 @@ interface AccountSettingsModalProps {
onClose: () => void;
selectedLanguage: string;
gitHubError: boolean;
analyticsConsent: string | null;
}
function AccountSettingsModal({
onClose,
selectedLanguage,
gitHubError,
analyticsConsent,
}: AccountSettingsModalProps) {
const data = useRouteLoaderData<typeof clientLoader>("routes/_oh");
const settingsFetcher = useFetcher<typeof settingsClientAction>({
@@ -32,6 +34,7 @@ function AccountSettingsModal({
const formData = new FormData(event.currentTarget);
const language = formData.get("language")?.toString();
const ghToken = formData.get("ghToken")?.toString();
const analytics = formData.get("analytics")?.toString() === "on";
const accountForm = new FormData();
const loginForm = new FormData();
@@ -44,6 +47,7 @@ function AccountSettingsModal({
accountForm.append("language", languageKey ?? "en");
}
if (ghToken) loginForm.append("ghToken", ghToken);
accountForm.append("analytics", analytics.toString());
settingsFetcher.submit(accountForm, {
method: "POST",
@@ -101,6 +105,15 @@ function AccountSettingsModal({
)}
</div>
<label className="flex gap-2 items-center self-start">
<input
name="analytics"
type="checkbox"
defaultChecked={analyticsConsent === "true"}
/>
Enable analytics
</label>
<div className="flex flex-col gap-2 w-full">
<ModalButton
disabled={
@@ -21,13 +21,17 @@ export function BaseModalTitle({ title }: BaseModalTitleProps) {
}
interface BaseModalDescriptionProps {
description: React.ReactNode;
description?: React.ReactNode;
children?: React.ReactNode;
}
export function BaseModalDescription({
description,
children,
}: BaseModalDescriptionProps) {
return <span className="text-xs text-[#A3A3A3]">{description}</span>;
return (
<span className="text-xs text-[#A3A3A3]">{children || description}</span>
);
}
interface BaseModalProps {
+10 -6
View File
@@ -1,7 +1,6 @@
import React from "react";
import { Data } from "ws";
import EventLogger from "#/utils/event-logger";
import { getValidFallbackHost } from "#/utils/get-valid-fallback-host";
interface WebSocketClientOptions {
token: string | null;
@@ -46,12 +45,17 @@ function SocketProvider({ children }: SocketProviderProps) {
);
}
const fallback = getValidFallbackHost();
const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback;
const baseUrl =
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const ws = new WebSocket(
`${protocol}//${baseUrl}/ws${options?.token ? `?token=${options.token}` : ""}`,
);
const sessionToken = options?.token || "NO_JWT"; // not allowed to be empty or duplicated
const ghToken = localStorage.getItem("ghToken") || "NO_GITHUB";
const ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [
"openhands",
sessionToken,
ghToken,
]);
ws.addEventListener("open", (event) => {
setIsConnected(true);
+14 -1
View File
@@ -6,13 +6,25 @@
*/
import { RemixBrowser } from "@remix-run/react";
import { startTransition, StrictMode } from "react";
import React, { startTransition, StrictMode } from "react";
import { hydrateRoot } from "react-dom/client";
import { Provider } from "react-redux";
import posthog from "posthog-js";
import { SocketProvider } from "./context/socket";
import "./i18n";
import store from "./store";
function PosthogInit() {
React.useEffect(() => {
posthog.init("phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA", {
api_host: "https://us.i.posthog.com",
person_profiles: "identified_only",
});
}, []);
return null;
}
async function prepareApp() {
if (
process.env.NODE_ENV === "development" &&
@@ -34,6 +46,7 @@ prepareApp().then(() =>
<SocketProvider>
<Provider store={store}>
<RemixBrowser />
<PosthogInit />
</Provider>
</SocketProvider>
</StrictMode>,
+14 -9
View File
@@ -1,7 +1,7 @@
import { delay, http, HttpResponse } from "msw";
const openHandsHandlers = [
http.get("http://localhost:3000/api/options/models", async () => {
http.get("http://localhost:3001/api/options/models", async () => {
await delay();
return HttpResponse.json([
"gpt-3.5-turbo",
@@ -10,17 +10,17 @@ const openHandsHandlers = [
]);
}),
http.get("http://localhost:3000/api/options/agents", async () => {
http.get("http://localhost:3001/api/options/agents", async () => {
await delay();
return HttpResponse.json(["CodeActAgent", "CoActAgent"]);
}),
http.get("http://localhost:3000/api/options/security-analyzers", async () => {
http.get("http://localhost:3001/api/options/security-analyzers", async () => {
await delay();
return HttpResponse.json(["mock-invariant"]);
}),
http.get("http://localhost:3000/api/list-files", async ({ request }) => {
http.get("http://localhost:3001/api/list-files", async ({ request }) => {
await delay();
const token = request.headers
@@ -32,11 +32,11 @@ const openHandsHandlers = [
return HttpResponse.json(["file1.ts", "dir1/file2.ts", "file3.ts"]);
}),
http.post("http://localhost:3000/api/save-file", () =>
http.post("http://localhost:3001/api/save-file", () =>
HttpResponse.json(null, { status: 200 }),
),
http.get("http://localhost:3000/api/select-file", async ({ request }) => {
http.get("http://localhost:3001/api/select-file", async ({ request }) => {
await delay();
const token = request.headers
@@ -58,7 +58,7 @@ const openHandsHandlers = [
return HttpResponse.json(null, { status: 404 });
}),
http.post("http://localhost:3000/api/submit-feedback", async () => {
http.post("http://localhost:3001/api/submit-feedback", async () => {
await delay(1200);
return HttpResponse.json({
@@ -70,7 +70,9 @@ const openHandsHandlers = [
export const handlers = [
...openHandsHandlers,
http.get("https://api.github.com/user/repos", ({ request }) => {
http.get("https://api.github.com/user/repos", async ({ request }) => {
await delay(3500);
const token = request.headers
.get("Authorization")
?.replace("Bearer", "")
@@ -85,7 +87,10 @@ export const handlers = [
{ id: 2, full_name: "octocat/earth" },
]);
}),
http.post("http://localhost:3000/api/submit-feedback", async () =>
http.post("http://localhost:3001/api/submit-feedback", async () =>
HttpResponse.json({ statusCode: 200 }, { status: 200 }),
),
http.post("https://us.i.posthog.com/e", async () =>
HttpResponse.json(null, { status: 200 }),
),
];
+27 -69
View File
@@ -1,54 +1,23 @@
import {
Await,
ClientActionFunctionArgs,
ClientLoaderFunctionArgs,
json,
defer,
redirect,
useLoaderData,
useRouteLoaderData,
} from "@remix-run/react";
import React from "react";
import React, { Suspense } from "react";
import { SuggestionBox } from "./suggestion-box";
import { TaskForm } from "./task-form";
import { HeroHeading } from "./hero-heading";
import { GitHubRepositorySelector } from "./github-repo-selector";
import {
isGitHubErrorReponse,
retrieveAllGitHubUserRepositories,
} from "#/api/github";
import ModalButton from "#/components/buttons/ModalButton";
import GitHubLogo from "#/assets/branding/github-logo.svg?react";
import { ConnectToGitHubModal } from "#/components/modals/connect-to-github-modal";
import { ModalBackdrop } from "#/components/modals/modal-backdrop";
import { retrieveAllGitHubUserRepositories } from "#/api/github";
import store from "#/store";
import { setInitialQuery } from "#/state/initial-query-slice";
import { clientLoader as rootClientLoader } from "#/routes/_oh";
import OpenHands from "#/api/open-hands";
import { generateGitHubAuthUrl } from "#/utils/generate-github-auth-url";
interface GitHubAuthProps {
onConnectToGitHub: () => void;
repositories: GitHubRepository[];
isLoggedIn: boolean;
}
function GitHubAuth({
onConnectToGitHub,
repositories,
isLoggedIn,
}: GitHubAuthProps) {
if (isLoggedIn) {
return <GitHubRepositorySelector repositories={repositories} />;
}
return (
<ModalButton
text="Connect to GitHub"
icon={<GitHubLogo width={20} height={20} />}
className="bg-[#791B80] w-full"
onClick={onConnectToGitHub}
/>
);
}
import { GitHubRepositoriesSuggestionBox } from "#/components/github-repositories-suggestion-box";
export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
let isSaas = false;
@@ -67,12 +36,12 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
const token = localStorage.getItem("token");
if (token) return redirect("/app");
let repositories: GitHubRepository[] = [];
let repositories: ReturnType<
typeof retrieveAllGitHubUserRepositories
> | null = null;
if (ghToken) {
const data = await retrieveAllGitHubUserRepositories(ghToken);
if (!isGitHubErrorReponse(data)) {
repositories = data;
}
const data = retrieveAllGitHubUserRepositories(ghToken);
repositories = data;
}
let githubAuthUrl: string | null = null;
@@ -81,7 +50,7 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
githubAuthUrl = generateGitHubAuthUrl(githubClientId, requestUrl);
}
return json({ repositories, githubAuthUrl });
return defer({ repositories, githubAuthUrl });
};
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
@@ -95,18 +64,8 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
function Home() {
const rootData = useRouteLoaderData<typeof rootClientLoader>("routes/_oh");
const { repositories, githubAuthUrl } = useLoaderData<typeof clientLoader>();
const [connectToGitHubModalOpen, setConnectToGitHubModalOpen] =
React.useState(false);
const [importedFile, setImportedFile] = React.useState<File | null>(null);
const handleConnectToGitHub = () => {
if (githubAuthUrl) {
window.location.href = githubAuthUrl;
} else {
setConnectToGitHubModalOpen(true);
}
};
return (
<div className="bg-root-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto">
<HeroHeading />
@@ -115,18 +74,24 @@ function Home() {
<TaskForm importedProjectZip={importedFile} />
</div>
<div className="flex gap-4 w-full">
<SuggestionBox
title="Open a Repo"
content={
<GitHubAuth
isLoggedIn={
!!rootData?.user && !isGitHubErrorReponse(rootData.user)
}
repositories={repositories}
onConnectToGitHub={handleConnectToGitHub}
<Suspense
fallback={
<SuggestionBox
title="Open a Repo"
content="Loading repositories..."
/>
}
/>
>
<Await resolve={repositories}>
{(resolvedRepositories) => (
<GitHubRepositoriesSuggestionBox
repositories={resolvedRepositories}
gitHubAuthUrl={githubAuthUrl}
user={rootData?.user || null}
/>
)}
</Await>
</Suspense>
<SuggestionBox
title={importedFile ? "Project Loaded" : "+ Import Project"}
content={
@@ -159,13 +124,6 @@ function Home() {
/>
</div>
</div>
{connectToGitHubModalOpen && (
<ModalBackdrop onClose={() => setConnectToGitHubModalOpen(false)}>
<ConnectToGitHubModal
onClose={() => setConnectToGitHubModalOpen(false)}
/>
</ModalBackdrop>
)}
</div>
);
}
@@ -1,18 +1,21 @@
import { Editor, Monaco } from "@monaco-editor/react";
import { Editor, EditorProps } from "@monaco-editor/react";
import React from "react";
import { useTranslation } from "react-i18next";
import { VscCode } from "react-icons/vsc";
import { type editor } from "monaco-editor";
import toast from "react-hot-toast";
import { I18nKey } from "#/i18n/declaration";
import { useFiles } from "#/context/files";
import OpenHands from "#/api/open-hands";
interface CodeEditorCompoonentProps {
onMount: EditorProps["onMount"];
isReadOnly: boolean;
}
function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
function CodeEditorCompoonent({
onMount,
isReadOnly,
}: CodeEditorCompoonentProps) {
const { t } = useTranslation();
const {
files,
@@ -22,22 +25,6 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
saveFileContent: saveNewFileContent,
} = useFiles();
const handleEditorDidMount = React.useCallback(
(editor: editor.IStandaloneCodeEditor, monaco: Monaco): void => {
monaco.editor.defineTheme("my-theme", {
base: "vs-dark",
inherit: true,
rules: [],
colors: {
"editor.background": "#171717",
},
});
monaco.editor.setTheme("my-theme");
},
[],
);
const handleEditorChange = (value: string | undefined) => {
if (selectedPath && value) modifyFileContent(selectedPath, value);
};
@@ -49,8 +36,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
if (content) {
try {
const token = localStorage.getItem("token")?.toString();
if (token) await OpenHands.saveFile(token, selectedPath, content);
await OpenHands.saveFile(selectedPath, content);
} catch (error) {
toast.error("Failed to save file");
}
@@ -68,7 +54,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
return (
<div
data-testid="code-editor-empty-message"
className="flex flex-col items-center text-neutral-400"
className="flex flex-col h-full items-center justify-center text-neutral-400"
>
<VscCode size={100} />
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
@@ -79,7 +65,6 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
return (
<Editor
data-testid="code-editor"
height="100%"
path={selectedPath ?? undefined}
defaultValue=""
value={
@@ -87,7 +72,7 @@ function CodeEditorCompoonent({ isReadOnly }: CodeEditorCompoonentProps) {
? modifiedFiles[selectedPath] || files[selectedPath]
: undefined
}
onMount={handleEditorDidMount}
onMount={onMount}
onChange={handleEditorChange}
options={{ readOnly: isReadOnly }}
/>
+41 -16
View File
@@ -1,12 +1,13 @@
import React from "react";
import { useSelector } from "react-redux";
import { json, useLoaderData, useRouteError } from "@remix-run/react";
import { json, useRouteError } from "@remix-run/react";
import toast from "react-hot-toast";
import { editor } from "monaco-editor";
import { EditorProps } from "@monaco-editor/react";
import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import FileExplorer from "#/components/file-explorer/FileExplorer";
import OpenHands from "#/api/open-hands";
import { useSocket } from "#/context/socket";
import CodeEditorCompoonent from "./code-editor-component";
import { useFiles } from "#/context/files";
import { EditorActions } from "#/components/editor-actions";
@@ -28,8 +29,7 @@ export function ErrorBoundary() {
}
function CodeEditor() {
const { token } = useLoaderData<typeof clientLoader>();
const { runtimeActive } = useSocket();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const {
setPaths,
selectedPath,
@@ -37,6 +37,27 @@ function CodeEditor() {
saveFileContent: saveNewFileContent,
discardChanges,
} = useFiles();
const [fileExplorerIsOpen, setFileExplorerIsOpen] = React.useState(true);
const editorRef = React.useRef<editor.IStandaloneCodeEditor | null>(null);
const toggleFileExplorer = () => {
setFileExplorerIsOpen((prev) => !prev);
editorRef.current?.layout({ width: 0, height: 0 });
};
const handleEditorDidMount: EditorProps["onMount"] = (e, monaco) => {
editorRef.current = e;
monaco.editor.defineTheme("oh-dark", {
base: "vs-dark",
inherit: true,
rules: [],
colors: {
"editor.background": "#171717",
},
});
monaco.editor.setTheme("oh-dark");
};
const [errors, setErrors] = React.useState<{ getFiles: string | null }>({
getFiles: null,
@@ -47,15 +68,14 @@ function CodeEditor() {
);
React.useEffect(() => {
// only retrieve files if connected to WS to prevent requesting before runtime is ready
if (runtimeActive && token) {
OpenHands.getFiles(token)
if (curAgentState === AgentState.INIT) {
OpenHands.getFiles()
.then(setPaths)
.catch(() => {
setErrors({ getFiles: "Failed to retrieve files" });
});
}
}, [runtimeActive, token]);
}, [curAgentState]);
// Code editing is only allowed when the agent is paused, finished, or awaiting user input (server rules)
const isEditingAllowed = React.useMemo(
@@ -69,9 +89,9 @@ function CodeEditor() {
const handleSave = async () => {
if (selectedPath) {
const content = modifiedFiles[selectedPath];
if (content && token) {
if (content) {
try {
await OpenHands.saveFile(token, selectedPath, content);
await OpenHands.saveFile(selectedPath, content);
saveNewFileContent(selectedPath);
} catch (error) {
toast.error("Failed to save file");
@@ -85,9 +105,13 @@ function CodeEditor() {
};
return (
<div className="flex h-full w-full bg-neutral-900 relative">
<FileExplorer error={errors.getFiles} />
<div className="flex flex-col min-h-0 w-full">
<div className="flex h-full bg-neutral-900 relative">
<FileExplorer
isOpen={fileExplorerIsOpen}
onToggle={toggleFileExplorer}
error={errors.getFiles}
/>
<div className="w-full">
{selectedPath && (
<div className="flex w-full items-center justify-between self-end p-2">
<span className="text-sm text-neutral-500">{selectedPath}</span>
@@ -98,9 +122,10 @@ function CodeEditor() {
/>
</div>
)}
<div className="flex grow items-center justify-center">
<CodeEditorCompoonent isReadOnly={!isEditingAllowed} />
</div>
<CodeEditorCompoonent
onMount={handleEditorDidMount}
isReadOnly={!isEditingAllowed}
/>
</div>
</div>
);
+5 -6
View File
@@ -72,9 +72,8 @@ const isAgentStateChange = (
export const clientLoader = async () => {
const ghToken = localStorage.getItem("ghToken");
try {
const isAuthed = await userIsAuthenticated(ghToken);
const isAuthed = await userIsAuthenticated();
if (!isAuthed) {
clearSession();
return redirect("/");
@@ -290,21 +289,21 @@ function App() {
React.useEffect(() => {
(async () => {
if (runtimeActive && token && importedProjectZip) {
if (runtimeActive && importedProjectZip) {
// upload files action
try {
const blob = base64ToBlob(importedProjectZip);
const file = new File([blob], "imported-project.zip", {
type: blob.type,
});
await OpenHands.uploadFiles(token, [file]);
await OpenHands.uploadFiles([file]);
dispatch(setImportedProjectZip(null));
} catch (error) {
toast.error("Failed to upload project files.");
}
}
})();
}, [runtimeActive, token, importedProjectZip]);
}, [runtimeActive, importedProjectZip]);
const {
isOpen: securityModalIsOpen,
@@ -315,7 +314,7 @@ function App() {
return (
<div className="flex flex-col h-full gap-3">
<div className="flex h-full overflow-auto gap-3">
<Container className="w-[375px] max-h-full">
<Container className="w-[390px] max-h-full">
<ChatInterface />
</Container>
+21 -2
View File
@@ -10,6 +10,8 @@ import {
Outlet,
ClientLoaderFunctionArgs,
} from "@remix-run/react";
import posthog from "posthog-js";
import { useDispatch } from "react-redux";
import { retrieveGitHubUser, isGitHubErrorReponse } from "#/api/github";
import OpenHands from "#/api/open-hands";
import CogTooth from "#/assets/cog-tooth";
@@ -28,6 +30,9 @@ import DocsIcon from "#/assets/docs.svg?react";
import { userIsAuthenticated } from "#/utils/user-is-authenticated";
import { generateGitHubAuthUrl } from "#/utils/generate-github-auth-url";
import { WaitlistModal } from "#/components/waitlist-modal";
import { AnalyticsConsentFormModal } from "#/components/analytics-consent-form-modal";
import { setCurrentAgentState } from "#/state/agentSlice";
import AgentState from "#/types/AgentState";
export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
try {
@@ -41,12 +46,20 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
let token = localStorage.getItem("token");
const ghToken = localStorage.getItem("ghToken");
const analyticsConsent = localStorage.getItem("analytics-consent");
const userConsents = analyticsConsent === "true";
let isAuthed: boolean = false;
if (!userConsents) {
posthog.opt_out_capturing();
} else {
posthog.opt_in_capturing();
}
let isAuthed = false;
let githubAuthUrl: string | null = null;
try {
isAuthed = await userIsAuthenticated(ghToken);
isAuthed = await userIsAuthenticated();
if (!isAuthed && window.__GITHUB_CLIENT_ID__) {
const requestUrl = new URL(request.url);
githubAuthUrl = generateGitHubAuthUrl(
@@ -79,6 +92,7 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
user,
settingsIsUpdated,
settings,
analyticsConsent,
});
};
@@ -132,9 +146,11 @@ export default function MainApp() {
githubAuthUrl,
settingsIsUpdated,
settings,
analyticsConsent,
} = useLoaderData<typeof clientLoader>();
const logoutFetcher = useFetcher({ key: "logout" });
const endSessionFetcher = useFetcher({ key: "end-session" });
const dispatch = useDispatch();
const [accountSettingsModalOpen, setAccountSettingsModalOpen] =
React.useState(false);
@@ -204,6 +220,7 @@ export default function MainApp() {
const handleEndSession = () => {
setStartNewProjectModalIsOpen(false);
dispatch(setCurrentAgentState(AgentState.LOADING));
// call new session action and redirect to '/'
endSessionFetcher.submit(new FormData(), {
method: "POST",
@@ -304,6 +321,7 @@ export default function MainApp() {
onClose={handleAccountSettingsModalClose}
selectedLanguage={settings.LANGUAGE}
gitHubError={isGitHubErrorReponse(user)}
analyticsConsent={analyticsConsent}
/>
</ModalBackdrop>
)}
@@ -328,6 +346,7 @@ export default function MainApp() {
{!isAuthed && (
<WaitlistModal ghToken={ghToken} githubAuthUrl={githubAuthUrl} />
)}
{!analyticsConsent && <AnalyticsConsentFormModal />}
</div>
);
}
@@ -11,11 +11,11 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
const code = url.searchParams.get("code");
if (code) {
// request to the server to exchange the code for a token
const { access_token: accessToken } =
await OpenHands.getGitHubAccessToken(code);
// set the token in local storage
localStorage.setItem("ghToken", accessToken);
return redirect("/");
}
+9
View File
@@ -0,0 +1,9 @@
import { ClientActionFunctionArgs, json } from "@remix-run/react";
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
const formData = await request.formData();
const userConsents = formData.get("analytics") === "on";
localStorage.setItem("analytics-consent", userConsents.toString());
return json(null);
};
+3
View File
@@ -28,6 +28,9 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
const LANGUAGE = formData.get("language")?.toString();
if (LANGUAGE) saveSettings({ LANGUAGE });
const ANALYTICS = formData.get("analytics")?.toString() ?? "false";
localStorage.setItem("analytics-consent", ANALYTICS);
return json({ success: true });
}
-47
View File
@@ -1,47 +0,0 @@
import { ClientActionFunctionArgs, json } from "@remix-run/react";
import { Feedback } from "#/api/open-hands.types";
import OpenHands from "#/api/open-hands";
const VIEWER_PAGE = "https://www.all-hands.dev/share";
const isFeedback = (feedback: unknown): feedback is Feedback => {
if (typeof feedback !== "object" || feedback === null) {
return false;
}
return (
"version" in feedback &&
"email" in feedback &&
"token" in feedback &&
"feedback" in feedback &&
"permissions" in feedback &&
"trajectory" in feedback
);
};
export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
const formData = await request.formData();
const feedback = formData.get("feedback")?.toString();
const token = localStorage.getItem("token");
if (token && feedback) {
const parsed = JSON.parse(feedback);
if (isFeedback(parsed)) {
try {
const response = await OpenHands.sendFeedback(token, parsed);
if (response.statusCode === 200) {
const { message, feedback_id: feedbackId, password } = response.body;
const link = `${VIEWER_PAGE}?share_id=${feedbackId}`;
return json({
success: true,
data: { message, link, password },
});
}
} catch (error) {
return json({ success: false, data: null });
}
}
}
return json({ success: false, data: null });
};
+31 -3
View File
@@ -1,14 +1,26 @@
import { getToken } from "./auth";
import { getToken, getGitHubToken } from "./auth";
import toast from "#/utils/toast";
const WAIT_FOR_AUTH_DELAY_MS = 500;
const UNAUTHED_ROUTE_PREFIXES = [
"/api/authenticate",
"/api/options/",
"/config.json",
"/api/github/callback",
];
export async function request(
url: string,
options: RequestInit = {},
disableToast: boolean = false,
returnResponse: boolean = false,
maxRetries: number = 3,
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
): Promise<any> {
if (maxRetries < 0) {
throw new Error("Max retries exceeded");
}
const onFail = (msg: string) => {
if (!disableToast) {
toast.error("api", msg);
@@ -16,12 +28,17 @@ export async function request(
throw new Error(msg);
};
const needsAuth = !url.startsWith("/api/options/");
const needsAuth = !UNAUTHED_ROUTE_PREFIXES.some((prefix) =>
url.startsWith(prefix),
);
const token = getToken();
const githubToken = getGitHubToken();
if (!token && needsAuth) {
return new Promise((resolve) => {
setTimeout(() => {
resolve(request(url, options, disableToast));
resolve(
request(url, options, disableToast, returnResponse, maxRetries - 1),
);
}, WAIT_FOR_AUTH_DELAY_MS);
});
}
@@ -32,6 +49,13 @@ export async function request(
Authorization: `Bearer ${token}`,
};
}
if (githubToken) {
// eslint-disable-next-line no-param-reassign
options.headers = {
...(options.headers || {}),
"X-GitHub-Token": githubToken,
};
}
let response = null;
try {
@@ -48,6 +72,10 @@ export async function request(
onFail(`Error fetching ${url}: ${response?.statusText}`);
}
if (returnResponse) {
return response;
}
try {
return await (response && response.json());
} catch (e) {
+20 -1
View File
@@ -1,4 +1,5 @@
const TOKEN_KEY = "token";
const GITHUB_TOKEN_KEY = "ghToken";
const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
@@ -10,4 +11,22 @@ const setToken = (token: string): void => {
localStorage.setItem(TOKEN_KEY, token);
};
export { getToken, setToken, clearToken };
const getGitHubToken = (): string =>
localStorage.getItem(GITHUB_TOKEN_KEY) ?? "";
const setGitHubToken = (token: string): void => {
localStorage.setItem(GITHUB_TOKEN_KEY, token);
};
const clearGitHubToken = (): void => {
localStorage.removeItem(GITHUB_TOKEN_KEY);
};
export {
getToken,
setToken,
clearToken,
getGitHubToken,
setGitHubToken,
clearGitHubToken,
};
+1 -6
View File
@@ -4,12 +4,7 @@ import OpenHands from "#/api/open-hands";
* Downloads the current workspace as a .zip file.
*/
export const downloadWorkspace = async () => {
const token = localStorage.getItem("token");
if (!token) {
throw new Error("No token found");
}
const blob = await OpenHands.getWorkspaceZip(token);
const blob = await OpenHands.getWorkspaceZip();
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
@@ -1,19 +0,0 @@
/**
* Get the valid fallback host. Returns the host unless it is localhost, in which case it returns localhost:3000
* @returns Valid fallback host
*
* @example
* // If the host is localhost (e.g., localhost:5173), it returns localhost:3000
* const host = getValidFallbackHost(); // localhost:3000
*
* // If the host is not localhost, it returns the host
* const host = getValidFallbackHost(); // sub.example.com
*/
export const getValidFallbackHost = () => {
if (typeof window !== "undefined") {
return window.location.host;
}
// Fallback is localhost:3000 because that is the default port for the server
return "localhost:3000";
};
+7 -11
View File
@@ -1,16 +1,12 @@
import { retrieveGitHubUser, isGitHubErrorReponse } from "#/api/github";
import OpenHands from "#/api/open-hands";
export const userIsAuthenticated = async (ghToken: string | null) => {
if (window.__APP_MODE__ !== "saas") return true;
export const userIsAuthenticated = async () => {
if (window.__APP_MODE__ === "oss") return true;
let user: GitHubUser | GitHubErrorReponse | null = null;
if (ghToken) user = await retrieveGitHubUser(ghToken);
if (user && !isGitHubErrorReponse(user)) {
const isAuthed = await OpenHands.isAuthenticated(user.login);
return isAuthed;
try {
await OpenHands.authenticate();
return true;
} catch (error) {
return false;
}
return false;
};
+21 -13
View File
@@ -37,21 +37,29 @@ export const removeUnwantedKeys = (
"focused_element_bid",
];
return data.map((item) => {
// Create a shallow copy of item
const newItem = { ...item };
return data
.filter((item) => {
// Skip items that have a status key
if ("status" in item) {
return false;
}
return true;
})
.map((item) => {
// Create a shallow copy of item
const newItem = { ...item };
// Check if extras exists and delete it from a new extras object
if (newItem.extras) {
const newExtras = { ...newItem.extras };
UNDESIRED_KEYS.forEach((key) => {
delete newExtras[key as keyof typeof newExtras];
});
newItem.extras = newExtras;
}
// Check if extras exists and delete it from a new extras object
if (newItem.extras) {
const newExtras = { ...newItem.extras };
UNDESIRED_KEYS.forEach((key) => {
delete newExtras[key as keyof typeof newExtras];
});
newItem.extras = newExtras;
}
return newItem;
});
return newItem;
});
};
export const removeApiKey = (
@@ -16,6 +16,7 @@ from openhands.events.action import (
Action,
AgentDelegateAction,
AgentFinishAction,
BrowseInteractiveAction,
CmdRunAction,
FileEditAction,
IPythonRunCellAction,
@@ -23,6 +24,7 @@ from openhands.events.action import (
)
from openhands.events.observation import (
AgentDelegateObservation,
BrowserOutputObservation,
CmdOutputObservation,
FileEditObservation,
IPythonRunCellObservation,
@@ -42,7 +44,7 @@ from openhands.utils.prompt import PromptManager
class CodeActAgent(Agent):
VERSION = '2.1'
VERSION = '2.2'
"""
The Code Act Agent is a minimalist agent.
The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step.
@@ -105,11 +107,11 @@ class CodeActAgent(Agent):
if self.function_calling_active:
# Function calling mode
self.tools = codeact_function_calling.get_tools(
codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
codeact_enable_browsing=self.config.codeact_enable_browsing,
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
)
logger.info(
logger.debug(
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}'
)
self.system_prompt = codeact_function_calling.SYSTEM_PROMPT
@@ -142,10 +144,10 @@ class CodeActAgent(Agent):
Args:
action (Action): The action to convert. Can be one of:
- AgentDelegateAction: For delegating tasks to other agents
- CmdRunAction: For executing bash commands
- IPythonRunCellAction: For running IPython code
- FileEditAction: For editing files
- BrowseInteractiveAction: For browsing the web
- AgentFinishAction: For ending the interaction
- MessageAction: For sending messages
pending_tool_call_action_messages (dict[str, Message]): Dictionary mapping response IDs
@@ -169,6 +171,7 @@ class CodeActAgent(Agent):
CmdRunAction,
IPythonRunCellAction,
FileEditAction,
BrowseInteractiveAction,
),
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
if self.function_calling_active:
@@ -185,13 +188,17 @@ class CodeActAgent(Agent):
pending_tool_call_action_messages[llm_response.id] = Message(
role=assistant_msg.role,
# tool call content SHOULD BE a string
content=[TextContent(text=assistant_msg.content)]
content=[TextContent(text=assistant_msg.content or '')]
if assistant_msg.content is not None
else [],
tool_calls=assistant_msg.tool_calls,
)
return []
else:
assert not isinstance(action, BrowseInteractiveAction), (
'BrowseInteractiveAction is not supported in non-function calling mode. Action: '
+ str(action)
)
content = [TextContent(text=self.action_parser.action_to_str(action))]
return [
Message(
@@ -201,7 +208,7 @@ class CodeActAgent(Agent):
]
elif isinstance(action, MessageAction):
role = 'user' if action.source == 'user' else 'assistant'
content = [TextContent(text=action.content)]
content = [TextContent(text=action.content or '')]
if self.llm.vision_is_active() and action.images_urls:
content.append(ImageContent(image_urls=action.images_urls))
return [
@@ -266,6 +273,12 @@ class CodeActAgent(Agent):
elif isinstance(obs, FileEditObservation):
text = obs_prefix + truncate_content(str(obs), max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, BrowserOutputObservation):
text = obs.get_agent_obs_text()
message = Message(
role='user',
content=[TextContent(text=obs_prefix + text)],
)
elif isinstance(obs, AgentDelegateObservation):
text = obs_prefix + truncate_content(
obs.outputs['content'] if 'content' in obs.outputs else '',
@@ -335,6 +348,7 @@ class CodeActAgent(Agent):
}
if self.function_calling_active:
params['tools'] = self.tools
params['parallel_tool_calls'] = False
else:
params['stop'] = [
'</execute_ipython>',
@@ -5,6 +5,7 @@ This is similar to the functionality of `CodeActResponseParser`.
import json
from browsergym.core.action.highlevel import HighLevelActionSet
from litellm import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
@@ -16,6 +17,7 @@ from openhands.events.action import (
Action,
AgentDelegateAction,
AgentFinishAction,
BrowseInteractiveAction,
CmdRunAction,
FileEditAction,
IPythonRunCellAction,
@@ -272,24 +274,146 @@ StrReplaceEditorTool = ChatCompletionToolParam(
),
)
_BROWSER_DELEGATION = """Delegate the task to another browsing agent.
The assistant should delegate the task if it needs to browse the Internet.
# from browsergym/core/action/highlevel.py
_browser_action_space = HighLevelActionSet(
subsets=['bid', 'nav'],
strict=False, # less strict on the parsing of the actions
multiaction=True, # enable to agent to take multiple actions at once
)
_BROWSER_DESCRIPTION = """Interact with the browser using Python code.
The following 15 functions are available. Nothing else is supported.
goto(url: str)
Description: Navigate to a url.
Examples:
goto('http://www.example.com')
go_back()
Description: Navigate to the previous page in history.
Examples:
go_back()
go_forward()
Description: Navigate to the next page in history.
Examples:
go_forward()
noop(wait_ms: float = 1000)
Description: Do nothing, and optionally wait for the given time (in milliseconds).
You can use this to get the current page content and/or wait for the page to load.
Examples:
noop()
noop(500)
scroll(delta_x: float, delta_y: float)
Description: Scroll horizontally and vertically. Amounts in pixels, positive for right or down scrolling, negative for left or up scrolling. Dispatches a wheel event.
Examples:
scroll(0, 200)
scroll(-50.2, -100.5)
fill(bid: str, value: str)
Description: Fill out a form field. It focuses the element and triggers an input event with the entered text. It works for <input>, <textarea> and [contenteditable] elements.
Examples:
fill('237', 'example value')
fill('45', 'multi-line\nexample')
fill('a12', 'example with "quotes"')
select_option(bid: str, options: str | list[str])
Description: Select one or multiple options in a <select> element. You can specify option value or label to select. Multiple options can be selected.
Examples:
select_option('a48', 'blue')
select_option('c48', ['red', 'green', 'blue'])
click(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
Description: Click an element.
Examples:
click('a51')
click('b22', button='right')
click('48', button='middle', modifiers=['Shift'])
dblclick(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
Description: Double click an element.
Examples:
dblclick('12')
dblclick('ca42', button='right')
dblclick('178', button='middle', modifiers=['Shift'])
hover(bid: str)
Description: Hover over an element.
Examples:
hover('b8')
press(bid: str, key_comb: str)
Description: Focus the matching element and press a combination of keys. It accepts the logical key names that are emitted in the keyboardEvent.key property of the keyboard events: Backquote, Minus, Equal, Backslash, Backspace, Tab, Delete, Escape, ArrowDown, End, Enter, Home, Insert, PageDown, PageUp, ArrowRight, ArrowUp, F1 - F12, Digit0 - Digit9, KeyA - KeyZ, etc. You can alternatively specify a single character you'd like to produce such as "a" or "#". Following modification shortcuts are also supported: Shift, Control, Alt, Meta, ShiftLeft, ControlOrMeta. ControlOrMeta resolves to Control on Windows and Linux and to Meta on macOS.
Examples:
press('88', 'Backspace')
press('a26', 'ControlOrMeta+a')
press('a61', 'Meta+Shift+t')
focus(bid: str)
Description: Focus the matching element.
Examples:
focus('b455')
clear(bid: str)
Description: Clear the input field.
Examples:
clear('996')
drag_and_drop(from_bid: str, to_bid: str)
Description: Perform a drag & drop. Hover the element that will be dragged. Press left mouse button. Move mouse to the element that will receive the drop. Release left mouse button.
Examples:
drag_and_drop('56', '498')
upload_file(bid: str, file: str | list[str])
Description: Click an element and wait for a "filechooser" event, then select one or multiple input files for upload. Relative file paths are resolved relative to the current working directory. An empty list clears the selected files.
Examples:
upload_file('572', '/home/user/my_receipt.pdf')
upload_file('63', ['/home/bob/Documents/image.jpg', '/home/bob/Documents/file.zip'])
Multiple actions can be provided at once, but will be executed sequentially without any feedback from the page.
More than 2-3 actions usually leads to failure or unexpected behavior. Example:
fill('a12', 'example with "quotes"')
click('a51')
click('48', button='middle', modifiers=['Shift'])
"""
BrowserDelegationTool = ChatCompletionToolParam(
for _, action in _browser_action_space.action_set.items():
assert (
action.signature in _BROWSER_DESCRIPTION
), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.signature}'
assert (
action.description in _BROWSER_DESCRIPTION
), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.description}'
BrowserTool = ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='delegate_to_browsing_agent',
description=_BROWSER_DELEGATION,
name='browser',
description=_BROWSER_DESCRIPTION,
parameters={
'type': 'object',
'properties': {
'task': {
'code': {
'type': 'string',
'description': 'The task for the browsing agent to execute. It should include all the necessary context and specify what information the browsing agent should return.',
},
'description': 'The Python code that interacts with the browser.',
}
},
'required': ['task'],
'required': ['code'],
},
),
)
@@ -357,6 +481,8 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
f'TOOL CALL: str_replace_editor -> file_editor with code: {code}'
)
action = IPythonRunCellAction(code=code, include_extra=False)
elif tool_call.function.name == 'browser':
action = BrowseInteractiveAction(browser_actions=arguments['code'])
else:
raise RuntimeError(f'Unknown tool call: {tool_call.function.name}')
@@ -381,13 +507,13 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
def get_tools(
codeact_enable_browsing_delegate: bool = False,
codeact_enable_browsing: bool = False,
codeact_enable_llm_editor: bool = False,
codeact_enable_jupyter: bool = False,
) -> list[ChatCompletionToolParam]:
tools = [CmdRunTool, FinishTool]
if codeact_enable_browsing_delegate:
tools.append(BrowserDelegationTool)
if codeact_enable_browsing:
tools.append(BrowserTool)
if codeact_enable_jupyter:
tools.append(IPythonTool)
if codeact_enable_llm_editor:
+16 -6
View File
@@ -156,7 +156,7 @@ class AgentController:
if exception is not None and isinstance(exception, litellm.AuthenticationError):
detail = 'Please check your credentials. Is your API key correct?'
self.event_stream.add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.USER
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
)
async def start_step_loop(self):
@@ -346,7 +346,8 @@ class AgentController:
self.state.agent_state = new_state
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
AgentStateChangedObservation('', self.state.agent_state),
EventSource.ENVIRONMENT,
)
if new_state == AgentState.INIT and self.state.resume_state:
@@ -423,7 +424,8 @@ class AgentController:
if self._is_stuck():
# This need to go BEFORE report_error to sync metrics
self.event_stream.add_event(
FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
FatalErrorObservation('Agent got stuck in a loop'),
EventSource.ENVIRONMENT,
)
return
@@ -507,6 +509,16 @@ class AgentController:
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# emit AgentDelegateObservation when the delegate terminates due to error
delegate_outputs = (
self.delegate.state.outputs if self.delegate.state else {}
)
content = (
f'{self.delegate.agent.name} encountered an error during execution.'
)
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
self.event_stream.add_event(obs, EventSource.AGENT)
# close the delegate upon error
await self.delegate.close()
self.delegate = None
@@ -532,9 +544,7 @@ class AgentController:
content = (
f'{self.delegate.agent.name} finishes task with {formatted_output}'
)
obs: Observation = AgentDelegateObservation(
outputs=outputs, content=content
)
obs = AgentDelegateObservation(outputs=outputs, content=content)
# clean up delegate status
self.delegate = None
+2 -2
View File
@@ -61,7 +61,7 @@ def display_event(event: Event):
if hasattr(event, 'thought'):
display_message(event.thought)
if isinstance(event, MessageAction):
if event.source != EventSource.USER:
if event.source == EventSource.AGENT:
display_message(event.content)
if isinstance(event, CmdRunAction):
display_command(event.command)
@@ -131,7 +131,7 @@ async def main():
next_message = input('How can I help? >> ')
if next_message == 'exit':
event_stream.add_event(
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
)
return
action = MessageAction(content=next_message)
+2 -2
View File
@@ -9,7 +9,7 @@ class AgentConfig:
Attributes:
function_calling: Whether function calling is enabled. Default is True.
codeact_enable_browsing_delegate: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
codeact_enable_browsing: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling.
codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False.
micro_agent_name: The name of the micro agent to use for this agent.
@@ -19,7 +19,7 @@ class AgentConfig:
"""
function_calling: bool = True
codeact_enable_browsing_delegate: bool = True
codeact_enable_browsing: bool = True
codeact_enable_llm_editor: bool = False
codeact_enable_jupyter: bool = True
micro_agent_name: str | None = None
+2
View File
@@ -49,6 +49,8 @@ class ImageContent(Content):
class Message(BaseModel):
# NOTE: this is not the same as EventSource
# These are the roles in the LLM's APIs
role: Literal['user', 'system', 'assistant', 'tool']
content: list[TextContent | ImageContent] = Field(default_factory=list)
cache_enabled: bool = False
+1
View File
@@ -9,6 +9,7 @@ from openhands.llm.metrics import Metrics
class EventSource(str, Enum):
AGENT = 'agent'
USER = 'user'
ENVIRONMENT = 'environment'
@dataclass
+46 -2
View File
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field
from browsergym.utils.obs import flatten_axtree_to_str
from openhands.core.schema import ObservationType
from openhands.events.observation.observation import Observation
@@ -29,7 +31,7 @@ class BrowserOutputObservation(Observation):
return 'Visited ' + self.url
def __str__(self) -> str:
return (
ret = (
'**BrowserOutputObservation**\n'
f'URL: {self.url}\n'
f'Error: {self.error}\n'
@@ -38,5 +40,47 @@ class BrowserOutputObservation(Observation):
f'Last browser action: {self.last_browser_action}\n'
f'Last browser action error: {self.last_browser_action_error}\n'
f'Focused element bid: {self.focused_element_bid}\n'
f'CONTENT: {self.content}\n'
f'Content: {self.content}\n'
)
ret += '--- Agent Observation ---\n'
ret += self.get_agent_obs_text()
return ret
def get_agent_obs_text(self) -> str:
"""Get a concise text that will be shown to the agent."""
text = f'[Current URL: {self.url}]\n'
text += f'[Focused element bid: {self.focused_element_bid}]\n\n'
if self.error:
text += (
'================ BEGIN error message ===============\n'
'The following error occurred when executing the last action:\n'
f'{self.last_browser_action_error}\n'
'================ END error message ===============\n'
)
else:
text += '[Action executed successfully.]\n'
try:
# We do not filter visible only here because we want to show the full content
# of the web page to the agent for simplicity.
# FIXME: handle the case when the web page is too large
cur_axtree_txt = self.get_axtree_str(filter_visible_only=False)
text += (
f'============== BEGIN accessibility tree ==============\n'
f'{cur_axtree_txt}\n'
f'============== END accessibility tree ==============\n'
)
except Exception as e:
text += f'\n[Error encountered when processing the accessibility tree: {e}]'
return text
def get_axtree_str(self, filter_visible_only: bool = False) -> str:
cur_axtree_txt = flatten_axtree_to_str(
self.axtree_object,
extra_properties=self.extra_element_properties,
with_clickable=True,
skip_generic=False,
filter_visible_only=filter_visible_only,
)
self._axtree_str = cur_axtree_txt
return cur_axtree_txt
+28 -8
View File
@@ -11,6 +11,7 @@ from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.storage import FileStore
from openhands.utils.async_utils import call_sync_from_async
class EventStreamSubscriber(str, Enum):
@@ -22,14 +23,29 @@ class EventStreamSubscriber(str, Enum):
TEST = 'test'
def session_exists(sid: str, file_store: FileStore) -> bool:
async def session_exists(sid: str, file_store: FileStore) -> bool:
try:
file_store.list(f'sessions/{sid}')
await call_sync_from_async(file_store.list, f'sessions/{sid}')
return True
except FileNotFoundError:
return False
class AsyncEventStreamWrapper:
def __init__(self, event_stream, *args, **kwargs):
self.event_stream = event_stream
self.args = args
self.kwargs = kwargs
async def __aiter__(self):
loop = asyncio.get_running_loop()
# Create an async generator that yields events
for event in self.event_stream.get_events(*self.args, **self.kwargs):
# Run the blocking get_events() in a thread pool
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
@dataclass
class EventStream:
sid: str
@@ -71,7 +87,15 @@ class EventStream:
end_id=None,
reverse=False,
filter_out_type: tuple[type[Event], ...] | None = None,
filter_hidden=False,
) -> Iterable[Event]:
def should_filter(event: Event):
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
return True
if filter_out_type is not None and isinstance(event, filter_out_type):
return True
return False
if reverse:
if end_id is None:
end_id = self._cur_id - 1
@@ -79,9 +103,7 @@ class EventStream:
while event_id >= start_id:
try:
event = self.get_event(event_id)
if filter_out_type is None or not isinstance(
event, filter_out_type
):
if not should_filter(event):
yield event
except FileNotFoundError:
logger.debug(f'No event found for ID {event_id}')
@@ -93,9 +115,7 @@ class EventStream:
break
try:
event = self.get_event(event_id)
if filter_out_type is None or not isinstance(
event, filter_out_type
):
if not should_filter(event):
yield event
except FileNotFoundError:
break
+108 -80
View File
@@ -82,6 +82,7 @@ class LLM(RetryMixin, DebugMixin):
config: The LLM configuration.
metrics: The metrics to use.
"""
self._tried_model_info = False
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
)
@@ -91,56 +92,6 @@ class LLM(RetryMixin, DebugMixin):
# litellm actually uses base Exception here for unknown model
self.model_info: ModelInfo | None = None
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
except Exception as e:
logger.debug(f'Error getting model info: {e}')
if self.config.model.startswith('litellm_proxy/'):
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={'Authorization': f'Bearer {self.config.api_key}'},
)
resp_json = response.json()
if 'data' not in resp_json:
logger.error(
f'Error getting model info from LiteLLM proxy: {resp_json}'
)
all_model_info = resp_json.get('data', [])
current_model_info = next(
(
info
for info in all_model_info
if info['model_name']
== self.config.model.removeprefix('litellm_proxy/')
),
None,
)
if current_model_info:
self.model_info = current_model_info['model_info']
# Last two attempts to get model info from NAME
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception:
pass
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split('/')[-1]
)
# noinspection PyBroadException
except Exception:
pass
logger.debug(f'Model info: {self.model_info}')
if self.config.log_completions:
if self.config.log_completions_folder is None:
raise RuntimeError(
@@ -148,32 +99,26 @@ class LLM(RetryMixin, DebugMixin):
)
os.makedirs(self.config.log_completions_folder, exist_ok=True)
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
)
if self.config.max_output_tokens is None:
# Safe default for any potentially viable model
self.config.max_output_tokens = 4096
if self.model_info is not None:
# max_output_tokens has precedence over max_tokens, if either exists.
# litellm has models with both, one or none of these 2 parameters!
if 'max_output_tokens' in self.model_info and isinstance(
self.model_info['max_output_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
elif 'max_tokens' in self.model_info and isinstance(
self.model_info['max_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_tokens']
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
self._completion = partial(
litellm_completion,
@@ -207,6 +152,7 @@ class LLM(RetryMixin, DebugMixin):
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
self.init_model_info()
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -300,6 +246,87 @@ class LLM(RetryMixin, DebugMixin):
"""
return self._completion
def init_model_info(self):
if self._tried_model_info:
return
self._tried_model_info = True
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
except Exception as e:
logger.debug(f'Error getting model info: {e}')
if self.config.model.startswith('litellm_proxy/'):
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={'Authorization': f'Bearer {self.config.api_key}'},
)
resp_json = response.json()
if 'data' not in resp_json:
logger.error(
f'Error getting model info from LiteLLM proxy: {resp_json}'
)
all_model_info = resp_json.get('data', [])
current_model_info = next(
(
info
for info in all_model_info
if info['model_name']
== self.config.model.removeprefix('litellm_proxy/')
),
None,
)
if current_model_info:
self.model_info = current_model_info['model_info']
# Last two attempts to get model info from NAME
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception:
pass
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split('/')[-1]
)
# noinspection PyBroadException
except Exception:
pass
logger.debug(f'Model info: {self.model_info}')
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
if self.config.max_output_tokens is None:
# Safe default for any potentially viable model
self.config.max_output_tokens = 4096
if self.model_info is not None:
# max_output_tokens has precedence over max_tokens, if either exists.
# litellm has models with both, one or none of these 2 parameters!
if 'max_output_tokens' in self.model_info and isinstance(
self.model_info['max_output_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
elif 'max_tokens' in self.model_info and isinstance(
self.model_info['max_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_tokens']
def vision_is_active(self):
return not self.config.disable_vision and self._supports_vision()
@@ -324,14 +351,15 @@ class LLM(RetryMixin, DebugMixin):
Returns:
boolean: True if prompt caching is supported and enabled for the given model.
"""
return (
self.config.caching_prompt is True
and self.model_info is not None
and self.model_info.get('supports_prompt_caching', False)
and (
return self.config.caching_prompt is True and (
(
self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
)
or (
self.model_info is not None
and self.model_info.get('supports_prompt_caching', False)
)
)
def is_function_calling_active(self) -> bool:
+2
View File
@@ -12,6 +12,7 @@ from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import FatalErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import EventStream
@@ -33,6 +34,7 @@ class ShortTermHistory(list[Event]):
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
FatalErrorObservation,
)
def __init__(self):
+2
View File
@@ -136,6 +136,8 @@ class Runtime(FileEditRuntimeMixin):
)
observation._cause = event.id # type: ignore[attr-defined]
observation.tool_call_metadata = event.tool_call_metadata
# this might be unnecessary, since source should be set by the event stream when we're here
source = event.source if event.source else EventSource.AGENT
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
+5 -1
View File
@@ -81,7 +81,10 @@ class BrowserEnv:
raise ValueError(
f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
)
env = gym.make(self.browsergym_eval_env)
env = gym.make(
self.browsergym_eval_env,
tags_to_mark='all',
)
else:
env = gym.make(
'browsergym/openended',
@@ -89,6 +92,7 @@ class BrowserEnv:
wait_for_user_message=False,
headless=True,
disable_env_checker=True,
tags_to_mark='all',
)
obs, info = env.reset()
+1 -1
View File
@@ -17,7 +17,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
version_info = self.docker_client.version()
server_version = version_info.get('Version', '').replace('-', '.')
if tuple(map(int, server_version.split('.'))) < (18, 9):
if tuple(map(int, server_version.split('.')[:2])) < (18, 9):
raise RuntimeError('Docker server version must be >= 18.09 to use BuildKit')
self.rolling_logger = RollingLogger(max_lines=10)
+1 -3
View File
@@ -4,9 +4,7 @@ import tarfile
from glob import glob
from e2b import Sandbox as E2BSandbox
from e2b.sandbox.exception import (
TimeoutException,
)
from e2b.sandbox.exception import TimeoutException
from openhands.core.config import SandboxConfig
from openhands.core.logger import openhands_logger as logger
@@ -240,7 +240,7 @@ class EventStreamRuntime(Runtime):
@tenacity.retry(
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
wait=tenacity.wait_fixed(5),
)
def _init_container(self):
try:
@@ -46,13 +46,13 @@ class EditTool:
if command == 'view':
return self.view(_path, view_range)
elif command == 'create':
if not file_text:
if file_text is None:
raise ToolError('Parameter `file_text` is required for command: create')
self.write_file(_path, file_text)
self._file_history[_path].append(file_text)
return ToolResult(output=f'File created successfully at: {_path}')
elif command == 'str_replace':
if not old_str:
if old_str is None:
raise ToolError(
'Parameter `old_str` is required for command: str_replace'
)
@@ -62,7 +62,7 @@ class EditTool:
raise ToolError(
'Parameter `insert_line` is required for command: insert'
)
if not new_str:
if new_str is None:
raise ToolError('Parameter `new_str` is required for command: insert')
return self.insert(_path, insert_line, new_str)
elif command == 'undo_edit':
@@ -87,6 +87,7 @@ COPY ./code/pyproject.toml ./code/poetry.lock /openhands/code/
RUN if [ -d /openhands/code/openhands ]; then rm -rf /openhands/code/openhands; fi
COPY ./code/pyproject.toml ./code/poetry.lock /openhands/code/
COPY ./code/openhands /openhands/code/openhands
RUN chmod a+rwx /openhands/code/openhands/__init__.py
# ================================================================
# END: Build from versioned image
+1
View File
@@ -147,6 +147,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
new_event = action_from_dict(
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
)
# we should confirm only on agent actions
event_source = event.source if event.source else EventSource.AGENT
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
+7 -4
View File
@@ -1,5 +1,5 @@
import json
from typing import Any, Literal
from typing import Any, Literal, Optional
import requests
from pydantic import BaseModel
@@ -10,10 +10,12 @@ from openhands.core.logger import openhands_logger as logger
class FeedbackDataModel(BaseModel):
version: str
email: str
token: str
feedback: Literal['positive', 'negative']
polarity: Literal['positive', 'negative']
feedback: Literal[
'positive', 'negative'
] # TODO: remove this, its here for backward compatibility
permissions: Literal['public', 'private']
trajectory: list[dict[str, Any]]
trajectory: Optional[list[dict[str, Any]]]
FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory'
@@ -21,6 +23,7 @@ FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_tra
def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
# Start logging
feedback.feedback = feedback.polarity
display_feedback = feedback.model_dump()
if 'trajectory' in display_feedback:
display_feedback['trajectory'] = (
+128
View File
@@ -0,0 +1,128 @@
import os
import httpx
from openhands.core.logger import openhands_logger as logger
from openhands.server.sheets_client import GoogleSheetsClient
GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
class UserVerifier:
def __init__(self) -> None:
logger.info('Initializing UserVerifier')
self.file_users: list[str] | None = None
self.sheets_client: GoogleSheetsClient | None = None
self.spreadsheet_id: str | None = None
# Initialize from environment variables
self._init_file_users()
self._init_sheets_client()
def _init_file_users(self) -> None:
"""Load users from text file if configured"""
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
if not waitlist:
logger.info('GITHUB_USER_LIST_FILE not configured')
return
if not os.path.exists(waitlist):
logger.error(f'User list file not found: {waitlist}')
raise FileNotFoundError(f'User list file not found: {waitlist}')
try:
with open(waitlist, 'r') as f:
self.file_users = [line.strip() for line in f if line.strip()]
logger.info(
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
)
except Exception as e:
logger.error(f'Error reading user list file {waitlist}: {str(e)}')
def _init_sheets_client(self) -> None:
"""Initialize Google Sheets client if configured"""
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
if not sheet_id:
logger.info('GITHUB_USERS_SHEET_ID not configured')
return
logger.info('Initializing Google Sheets integration')
self.sheets_client = GoogleSheetsClient()
self.spreadsheet_id = sheet_id
def is_active(self) -> bool:
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
def is_user_allowed(self, username: str) -> bool:
"""Check if user is allowed based on file and/or sheet configuration"""
if not self.is_active():
return True
logger.info(f'Checking if GitHub user {username} is allowed')
if self.file_users:
if username in self.file_users:
logger.info(f'User {username} found in text file allowlist')
return True
logger.debug(f'User {username} not found in text file allowlist')
if self.sheets_client and self.spreadsheet_id:
sheet_users = self.sheets_client.get_usernames(self.spreadsheet_id)
if username in sheet_users:
logger.info(f'User {username} found in Google Sheets allowlist')
return True
logger.debug(f'User {username} not found in Google Sheets allowlist')
logger.info(f'User {username} not found in any allowlist')
return False
async def authenticate_github_user(auth_token) -> bool:
user_verifier = UserVerifier()
if not user_verifier.is_active():
logger.info('No user verification sources configured - allowing all users')
return True
logger.info('Checking GitHub token')
if not auth_token:
logger.warning('No GitHub token provided')
return False
login = await get_github_user(auth_token)
if not user_verifier.is_user_allowed(login):
logger.warning(f'GitHub user {login} not in allow list')
return False
logger.info(f'GitHub user {login} authenticated')
return True
async def get_github_user(token: str) -> str:
"""Get GitHub user info from token.
Args:
token: GitHub access token
Returns:
Tuple of (login, error_message)
If successful, error_message is None
If failed, login is None and error_message contains the error
"""
logger.info('Fetching GitHub user info from token')
headers = {
'Accept': 'application/vnd.github+json',
'Authorization': f'Bearer {token}',
'X-GitHub-Api-Version': '2022-11-28',
}
async with httpx.AsyncClient() as client:
logger.debug('Making request to GitHub API')
response = await client.get('https://api.github.com/user', headers=headers)
response.raise_for_status()
user_data = response.json()
login = user_data.get('login')
logger.info(f'Successfully retrieved GitHub user: {login}')
return login
+74 -52
View File
@@ -13,6 +13,11 @@ from pathspec.patterns import GitWildMatchPattern
from openhands.security.options import SecurityAnalyzers
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.github import (
GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET,
authenticate_github_user,
)
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
@@ -52,6 +57,7 @@ from openhands.events.observation import (
NullObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth import get_sid_from_token, sign_token
@@ -64,24 +70,6 @@ config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
# New global variable to store the user list
GITHUB_USER_LIST = None
# New function to load the user list
def load_github_user_list():
global GITHUB_USER_LIST
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
if waitlist:
with open(waitlist, 'r') as f:
GITHUB_USER_LIST = [line.strip() for line in f if line.strip()]
load_github_user_list()
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -216,7 +204,13 @@ async def attach_session(request: Request, call_next):
response = await call_next(request)
return response
# For all other methods, validate the Authorization header
github_token = request.headers.get('X-GitHub-Token')
if not await authenticate_github_user(github_token):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authenticated'},
)
if not request.headers.get('Authorization'):
logger.warning('Missing Authorization header')
return JSONResponse(
@@ -308,11 +302,28 @@ async def websocket_endpoint(websocket: WebSocket):
{"action": "finish", "args": {}}
```
"""
await asyncio.wait_for(websocket.accept(), 10)
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
if websocket.query_params.get('token'):
token = websocket.query_params.get('token')
sid = get_sid_from_token(token, config.jwt_secret)
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
@@ -320,18 +331,21 @@ async def websocket_endpoint(websocket: WebSocket):
return
else:
sid = str(uuid.uuid4())
token = sign_token({'sid': sid}, config.jwt_secret)
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': token, 'status': 'ok'})
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
for event in session.agent_session.event_stream.get_events(
start_id=latest_event_id + 1
):
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
@@ -469,19 +483,17 @@ async def list_files(request: Request, path: str | None = None):
)
runtime: Runtime = request.state.conversation.runtime
file_list = await asyncio.create_task(
call_sync_from_async(runtime.list_files, path)
)
file_list = await call_sync_from_async(runtime.list_files, path)
if path:
file_list = [os.path.join(path, f) for f in file_list]
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
def filter_for_gitignore(file_list, base_path):
async def filter_for_gitignore(file_list, base_path):
gitignore_path = os.path.join(base_path, '.gitignore')
try:
read_action = FileReadAction(gitignore_path)
observation = runtime.run_action(read_action)
observation = await call_sync_from_async(runtime.run_action, read_action)
spec = PathSpec.from_lines(
GitWildMatchPattern, observation.content.splitlines()
)
@@ -491,7 +503,7 @@ async def list_files(request: Request, path: str | None = None):
file_list = [entry for entry in file_list if not spec.match_file(entry)]
return file_list
file_list = filter_for_gitignore(file_list, '')
file_list = await filter_for_gitignore(file_list, '')
return file_list
@@ -634,14 +646,14 @@ async def upload_file(request: Request, files: list[UploadFile]):
@app.post('/api/submit-feedback')
async def submit_feedback(request: Request, feedback: FeedbackDataModel):
async def submit_feedback(request: Request):
"""Submit user feedback.
This function stores the provided feedback data.
To submit feedback:
```sh
curl -X POST -F "email=test@example.com" -F "token=abc" -F "feedback=positive" -F "permissions=private" -F "trajectory={}" http://localhost:3000/api/submit-feedback
curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:"
```
Args:
@@ -656,8 +668,23 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
"""
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_stream:
trajectory.append(event_to_dict(event))
feedback = FeedbackDataModel(
email=body.get('email', ''),
version=body.get('version', ''),
permissions=body.get('permissions', 'private'),
polarity=body.get('polarity', ''),
feedback=body.get('polarity', ''),
trajectory=trajectory,
)
try:
feedback_data = store_feedback(feedback)
feedback_data = await call_sync_from_async(store_feedback, feedback)
return JSONResponse(status_code=200, content=feedback_data)
except Exception as e:
logger.error(f'Error submitting feedback: {e}')
@@ -827,26 +854,21 @@ def github_callback(auth_code: AuthCode):
)
class User(BaseModel):
login: str # GitHub login handle
@app.post('/api/authenticate')
def authenticate(user: User | None = None):
global GITHUB_USER_LIST
async def authenticate(request: Request):
token = request.headers.get('X-GitHub-Token')
if not await authenticate_github_user(token):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Not authorized via GitHub waitlist'},
)
# Only check if waitlist is provided
if GITHUB_USER_LIST:
if user is None or user.login not in GITHUB_USER_LIST:
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={'error': 'User not on waitlist'},
)
return JSONResponse(
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
)
return response
class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
+2 -2
View File
@@ -14,7 +14,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
def __init__(self, app: ASGIApp, **kwargs) -> None:
super().__init__(app, **kwargs)
async def is_allowed_origin(self, origin: str) -> bool:
def is_allowed_origin(self, origin: str) -> bool:
if origin:
parsed = urlparse(origin)
hostname = parsed.hostname or ''
@@ -24,7 +24,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
return True
# For missing origin or other origins, use the parent class's logic
return await super().is_allowed_origin(origin)
return super().is_allowed_origin(origin)
class NoCacheMiddleware(BaseHTTPMiddleware):
+7 -7
View File
@@ -101,7 +101,6 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
self.loop = asyncio.get_running_loop()
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
@@ -118,15 +117,20 @@ class AgentSession:
agent_configs=agent_configs,
)
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
if self.controller:
self.controller.agent_task = self.controller.start_step_loop()
await self.controller.agent_task # type: ignore
async def close(self):
def close(self):
"""Closes the Agent session"""
self._closed = True
def inner_close():
asyncio.run(self._close())
asyncio.get_event_loop().run_in_executor(None, inner_close)
async def _close(self):
if self._closed:
return
if self.controller is not None:
@@ -138,10 +142,6 @@ class AgentSession:
if self.security_analyzer is not None:
await self.security_analyzer.close()
if self.loop:
self.loop.stop()
self._closed = True
def _create_security_analyzer(self, security_analyzer: str | None):
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
+3 -3
View File
@@ -35,7 +35,7 @@ class SessionManager:
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
if sid in self._sessions:
asyncio.create_task(self._sessions[sid].close())
self._sessions[sid].close()
self._sessions[sid] = Session(
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
)
@@ -47,7 +47,7 @@ class SessionManager:
return self._sessions.get(sid)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
if not session_exists(sid, self.file_store):
if not await session_exists(sid, self.file_store):
return None
c = Conversation(sid, file_store=self.file_store, config=self.config)
await c.connect()
@@ -87,7 +87,7 @@ class SessionManager:
for sid in session_ids_to_remove:
to_del_session: Session | None = self._sessions.pop(sid, None)
if to_del_session is not None:
await to_del_session.close()
to_del_session.close()
logger.debug(
f'Session {sid} and related resource have been removed due to inactivity.'
)
+23 -31
View File
@@ -25,8 +25,6 @@ from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.agent_session import AgentSession
from openhands.storage.files import FileStore
DEL_DELT_SEC = 60 * 60 * 5
class Session:
sid: str
@@ -49,9 +47,9 @@ class Session:
self.config = config
self.loop = asyncio.get_event_loop()
async def close(self):
def close(self):
self.is_alive = False
await self.agent_session.close()
self.agent_session.close()
async def loop_recv(self):
try:
@@ -65,18 +63,19 @@ class Session:
continue
await self.dispatch(data)
except WebSocketDisconnect:
await self.close()
logger.debug('WebSocket disconnected, sid: %s', self.sid)
logger.info('WebSocket disconnected, sid: %s', self.sid)
self.close()
except RuntimeError as e:
await self.close()
logger.exception('Error in loop_recv: %s', e)
self.close()
async def _initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
)
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
args = {key: value for key, value in data.get('args', {}).items()}
@@ -138,12 +137,19 @@ class Session:
return
if event.source == EventSource.AGENT:
await self.send(event_to_dict(event))
elif event.source == EventSource.USER and isinstance(
event, CmdOutputObservation
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
event, (CmdOutputObservation, AgentStateChangedObservation)
):
await self.send(event_to_dict(event))
# feedback from the environment to agent actions is understood as agent events by the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)
elif isinstance(event, ErrorObservation):
await self.send(event_to_dict(event))
# send error events as agent events to the UI
event_dict = event_to_dict(event)
event_dict['source'] = EventSource.AGENT
await self.send(event_dict)
async def dispatch(self, data: dict):
action = data.get('action', '')
@@ -165,10 +171,12 @@ class Session:
'Model does not support image upload, change to a different model or try without an image.'
)
return
if self.agent_session.loop:
if self.loop:
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.agent_session.loop
self._add_event(event, EventSource.USER), self.loop
) # type: ignore
else:
raise RuntimeError('No event loop found')
async def _add_event(self, event, event_source):
self.agent_session.event_stream.add_event(event, EventSource.USER)
@@ -192,26 +200,10 @@ class Session:
"""Sends an error message to the client."""
return await self.send({'error': True, 'message': message})
async def send_message(self, message: str) -> bool:
"""Sends a message to the client."""
return await self.send({'message': message})
async def send_status_message(self, message: str) -> bool:
"""Sends a status message to the client."""
return await self.send({'status': message})
def update_connection(self, ws: WebSocket):
self.websocket = ws
self.is_alive = True
self.last_active_ts = int(time.time())
def load_from_data(self, data: dict) -> bool:
self.last_active_ts = data.get('last_active_ts', 0)
if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
return False
self.is_alive = data.get('is_alive', False)
return True
def queue_status_message(self, message: str):
"""Queues a status message to be sent asynchronously."""
# Ensure the coroutine runs in the main event loop
+68
View File
@@ -0,0 +1,68 @@
from typing import List
from google.auth import default
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from openhands.core.logger import openhands_logger as logger
class GoogleSheetsClient:
def __init__(self):
"""Initialize Google Sheets client using workload identity.
Uses application default credentials which supports workload identity when running in GCP.
"""
logger.info('Initializing Google Sheets client with workload identity')
try:
credentials, project = default(
scopes=['https://www.googleapis.com/auth/spreadsheets.readonly']
)
logger.info(f'Successfully obtained credentials for project: {project}')
self.service = build('sheets', 'v4', credentials=credentials)
logger.info('Successfully initialized Google Sheets API service')
except Exception as e:
logger.error(f'Failed to initialize Google Sheets client: {str(e)}')
self.service = None
def get_usernames(self, spreadsheet_id: str, range_name: str = 'A:A') -> List[str]:
"""Get list of usernames from specified Google Sheet.
Args:
spreadsheet_id: The ID of the Google Sheet
range_name: The A1 notation of the range to fetch
Returns:
List of usernames from the sheet
"""
if not self.service:
logger.error('Google Sheets service not initialized')
return []
try:
logger.info(
f'Fetching usernames from sheet {spreadsheet_id}, range {range_name}'
)
result = (
self.service.spreadsheets()
.values()
.get(spreadsheetId=spreadsheet_id, range=range_name)
.execute()
)
values = result.get('values', [])
usernames = [
str(cell[0]).strip() for cell in values if cell and cell[0].strip()
]
logger.info(
f'Successfully fetched {len(usernames)} usernames from Google Sheet'
)
return usernames
except HttpError as err:
logger.error(f'Error accessing Google Sheet {spreadsheet_id}: {err}')
return []
except Exception as e:
logger.error(
f'Unexpected error accessing Google Sheet {spreadsheet_id}: {str(e)}'
)
return []
Generated
+20 -2
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "aenum"
@@ -2319,6 +2319,24 @@ files = [
google-auth = "*"
httplib2 = ">=0.19.0"
[[package]]
name = "google-auth-oauthlib"
version = "1.2.1"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.6"
files = [
{file = "google_auth_oauthlib-1.2.1-py2.py3-none-any.whl", hash = "sha256:2d58a27262d55aa1b87678c3ba7142a080098cbc2024f903c62355deb235d91f"},
{file = "google_auth_oauthlib-1.2.1.tar.gz", hash = "sha256:afd0cad092a2eaa53cd8e8298557d6de1034c6cb4a740500b5357b648af97263"},
]
[package.dependencies]
google-auth = ">=2.15.0"
requests-oauthlib = ">=0.7.0"
[package.extras]
tool = ["click (>=6.0.0)"]
[[package]]
name = "google-cloud-aiplatform"
version = "1.70.0"
@@ -10109,4 +10127,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "2b268ef696ace0d8170276407dbdeb414134477839ebe4b7ecf29b1a1fe2cef3"
content-hash = "2a4f90bb5c7f7d82160f57d71af7e81c7acef69426d0e1e46e1da09972a6215f"
+4 -3
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openhands-ai"
version = "0.11.0"
version = "0.12.0"
description = "OpenHands: Code Less, Make More"
authors = ["OpenHands"]
license = "MIT"
@@ -16,6 +16,9 @@ datasets = "*"
pandas = "*"
litellm = "^1.51.1"
google-generativeai = "*" # To use litellm with Gemini Pro API
google-api-python-client = "*" # For Google Sheets API
google-auth-httplib2 = "*" # For Google Sheets authentication
google-auth-oauthlib = "*" # For Google Sheets OAuth
termcolor = "*"
seaborn = "*"
docker = "*"
@@ -89,7 +92,6 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -120,7 +122,6 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
+82
View File
@@ -232,3 +232,85 @@ def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
)
_close_test_runtime(runtime)
def test_ipython_file_editor_permissions_as_openhands(temp_dir, runtime_cls):
"""Test file editor permission behavior when running as different users."""
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands=True)
sandbox_dir = _get_sandbox_folder(runtime)
# Create a file owned by root with restricted permissions
action = CmdRunAction(
command='sudo touch /root/test.txt && sudo chmod 600 /root/test.txt'
)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
# Try to view the file as openhands user - should fail with permission denied
test_code = "print(file_editor(command='view', path='/root/test.txt'))"
action = IPythonRunCellAction(code=test_code)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert 'Permission denied' in obs.content
# Try to edit the file as openhands user - should fail with permission denied
test_code = "print(file_editor(command='str_replace', path='/root/test.txt', old_str='', new_str='test'))"
action = IPythonRunCellAction(code=test_code)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert 'Permission denied' in obs.content
# Try to create a file in root directory - should fail with permission denied
test_code = (
"print(file_editor(command='create', path='/root/new.txt', file_text='test'))"
)
action = IPythonRunCellAction(code=test_code)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert 'Permission denied' in obs.content
# Try to use file editor in openhands sandbox directory - should work
test_code = f"""
# Create file
print(file_editor(command='create', path='{sandbox_dir}/test.txt', file_text='Line 1\\nLine 2\\nLine 3'))
# View file
print(file_editor(command='view', path='{sandbox_dir}/test.txt'))
# Edit file
print(file_editor(command='str_replace', path='{sandbox_dir}/test.txt', old_str='Line 2', new_str='New Line 2'))
# Undo edit
print(file_editor(command='undo_edit', path='{sandbox_dir}/test.txt'))
"""
action = IPythonRunCellAction(code=test_code)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert 'File created successfully' in obs.content
assert 'Line 1' in obs.content
assert 'Line 2' in obs.content
assert 'Line 3' in obs.content
assert 'New Line 2' in obs.content
assert 'Last edit to' in obs.content
assert 'undone successfully' in obs.content
# Clean up
action = CmdRunAction(command=f'rm -f {sandbox_dir}/test.txt')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
action = CmdRunAction(command='sudo rm -f /root/test.txt')
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
_close_test_runtime(runtime)
+3 -1
View File
@@ -207,7 +207,9 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
await event_stream.async_add_event(
non_fatal_error_obs, EventSource.ENVIRONMENT
)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream
+37 -37
View File
@@ -80,7 +80,7 @@ class TestStuckDetector:
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.USER)
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
def _impl_unterminated_string_error_events(
self, event_stream: EventStream, random_line: bool, incidents: int = 4
@@ -96,7 +96,7 @@ class TestStuckDetector:
code=code_snippet,
)
ipython_observation._cause = ipython_action._id
event_stream.add_event(ipython_observation, EventSource.USER)
event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
def test_history_too_short(
self, stuck_detector: StuckDetector, event_stream: EventStream
@@ -106,7 +106,7 @@ class TestStuckDetector:
observation = NullObservation(content='')
observation._cause = message_action.id
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(observation, EventSource.USER)
event_stream.add_event(observation, EventSource.ENVIRONMENT)
cmd_action = CmdRunAction(command='ls')
event_stream.add_event(cmd_action, EventSource.AGENT)
@@ -114,7 +114,7 @@ class TestStuckDetector:
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation._cause = cmd_action._id
event_stream.add_event(cmd_observation, EventSource.USER)
event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
# stuck_detector.state.history.set_event_stream(event_stream)
@@ -131,7 +131,7 @@ class TestStuckDetector:
# 2 events
event_stream.add_event(hello_action, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -139,7 +139,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_1._id
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
# 4 events
cmd_action_2 = CmdRunAction(command='ls')
@@ -148,13 +148,13 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_2._id
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
# 6 events
# random user message just because we can
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
# 8 events
assert stuck_detector.is_stuck() is False
@@ -166,7 +166,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_3._id
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
# 10 events
assert len(collect_events(event_stream)) == 10
@@ -191,7 +191,7 @@ class TestStuckDetector:
content='', command='ls', command_id=cmd_action_4._id
)
cmd_observation_4._cause = cmd_action_4._id
event_stream.add_event(cmd_observation_4, EventSource.USER)
event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
# 12 events
assert len(collect_events(event_stream)) == 12
@@ -223,14 +223,14 @@ class TestStuckDetector:
hello_observation = NullObservation(content='')
event_stream.add_event(hello_action, EventSource.USER)
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
# 2 events
cmd_action_1 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
error_observation_1 = ErrorObservation(content='Command not found')
error_observation_1._cause = cmd_action_1._id
event_stream.add_event(error_observation_1, EventSource.USER)
event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
# 4 events
cmd_action_2 = CmdRunAction(command='invalid_command')
@@ -239,26 +239,26 @@ class TestStuckDetector:
content='Command still not found or another error'
)
error_observation_2._cause = cmd_action_2._id
event_stream.add_event(error_observation_2, EventSource.USER)
event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
# 6 events
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
# 8 events
cmd_action_3 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
error_observation_3 = ErrorObservation(content='Different error')
error_observation_3._cause = cmd_action_3._id
event_stream.add_event(error_observation_3, EventSource.USER)
event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
# 10 events
cmd_action_4 = CmdRunAction(command='invalid_command')
event_stream.add_event(cmd_action_4, EventSource.AGENT)
error_observation_4 = ErrorObservation(content='Command not found')
error_observation_4._cause = cmd_action_4._id
event_stream.add_event(error_observation_4, EventSource.USER)
event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
# 12 events
with patch('logging.Logger.warning') as mock_warning:
@@ -366,7 +366,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_1._cause = ipython_action_1._id
event_stream.add_event(ipython_observation_1, EventSource.USER)
event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
ipython_action_2 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_2, EventSource.AGENT)
@@ -375,7 +375,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_2._cause = ipython_action_2._id
event_stream.add_event(ipython_observation_2, EventSource.USER)
event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
ipython_action_3 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_3, EventSource.AGENT)
@@ -384,7 +384,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_3._cause = ipython_action_3._id
event_stream.add_event(ipython_observation_3, EventSource.USER)
event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
ipython_action_4 = IPythonRunCellAction(code='print("hello')
event_stream.add_event(ipython_action_4, EventSource.AGENT)
@@ -393,7 +393,7 @@ class TestStuckDetector:
code='print("hello',
)
ipython_observation_4._cause = ipython_action_4._id
event_stream.add_event(ipython_observation_4, EventSource.USER)
event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
@@ -406,7 +406,7 @@ class TestStuckDetector:
message_action._source = EventSource.USER
event_stream.add_event(message_action, EventSource.USER)
message_observation = NullObservation(content='')
event_stream.add_event(message_observation, EventSource.USER)
event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -414,7 +414,7 @@ class TestStuckDetector:
command_id=1, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
@@ -422,7 +422,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.USER)
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
cmd_action_2 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
@@ -430,7 +430,7 @@ class TestStuckDetector:
command_id=2, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
read_action_2 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
@@ -438,12 +438,12 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.USER)
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
# one more message to break the pattern
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
cmd_action_3 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
@@ -451,7 +451,7 @@ class TestStuckDetector:
command_id=3, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
read_action_3 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
@@ -459,7 +459,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.USER)
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
@@ -475,7 +475,7 @@ class TestStuckDetector:
event_stream.add_event(hello_action, EventSource.USER)
hello_observation = NullObservation(content='')
hello_observation._cause = hello_action._id
event_stream.add_event(hello_observation, EventSource.USER)
event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
cmd_action_1 = CmdRunAction(command='ls')
event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -483,7 +483,7 @@ class TestStuckDetector:
command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
)
cmd_observation_1._cause = cmd_action_1._id
event_stream.add_event(cmd_observation_1, EventSource.USER)
event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
read_action_1 = FileReadAction(path='file1.txt')
event_stream.add_event(read_action_1, EventSource.AGENT)
@@ -491,7 +491,7 @@ class TestStuckDetector:
content='File content', path='file1.txt'
)
read_observation_1._cause = read_action_1._id
event_stream.add_event(read_observation_1, EventSource.USER)
event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
cmd_action_2 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_2, EventSource.AGENT)
@@ -499,7 +499,7 @@ class TestStuckDetector:
command_id=2, command='pwd', content='/home/user'
)
cmd_observation_2._cause = cmd_action_2._id
event_stream.add_event(cmd_observation_2, EventSource.USER)
event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
read_action_2 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_2, EventSource.AGENT)
@@ -507,11 +507,11 @@ class TestStuckDetector:
content='Another file content', path='file2.txt'
)
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.USER)
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
message_null_observation = NullObservation(content='')
event_stream.add_event(message_action, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.USER)
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
cmd_action_3 = CmdRunAction(command='pwd')
event_stream.add_event(cmd_action_3, EventSource.AGENT)
@@ -519,7 +519,7 @@ class TestStuckDetector:
command_id=cmd_action_3.id, command='pwd', content='/home/user'
)
cmd_observation_3._cause = cmd_action_3._id
event_stream.add_event(cmd_observation_3, EventSource.USER)
event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
read_action_3 = FileReadAction(path='file2.txt')
event_stream.add_event(read_action_3, EventSource.AGENT)
@@ -527,7 +527,7 @@ class TestStuckDetector:
content='Another file content', path='file2.txt'
)
read_observation_3._cause = read_action_3._id
event_stream.add_event(read_observation_3, EventSource.USER)
event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
assert stuck_detector.is_stuck() is False
@@ -572,7 +572,7 @@ class TestStuckDetector:
exit_code=0,
)
cmd_output_observation._cause = cmd_kill_action._id
event_stream.add_event(cmd_output_observation, EventSource.USER)
event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
event_stream.add_event(message_action_7, EventSource.AGENT)
+3
View File
@@ -50,6 +50,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
'max_output_tokens': 2000,
}
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 8000
assert llm.config.max_output_tokens == 2000
@@ -58,6 +59,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
def test_llm_init_without_model_info(mock_get_model_info, default_config):
mock_get_model_info.side_effect = Exception('Model info not available')
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 4096
assert llm.config.max_output_tokens == 4096
@@ -108,6 +110,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
'max_output_tokens': 1500,
}
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 7000
assert llm.config.max_output_tokens == 1500
mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
+1 -1
View File
@@ -88,7 +88,7 @@ def _create_observation_event(observation: str) -> Event:
event = Event()
event._id = -1
event._timestamp = datetime.now(timezone.utc).isoformat()
event._source = EventSource.USER
event._source = EventSource.ENVIRONMENT
event.observation = observation
return event
+2 -2
View File
@@ -155,7 +155,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='ls -l',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
message_action_2 = MessageAction("Now, let's create a new directory.")
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
@@ -169,7 +169,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
command='mkdir new_directory',
exit_code=0,
)
mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
codeact_agent.reset()
messages = codeact_agent._get_messages(